Repository: kxhit/EscherNet Branch: main Commit: 10b650492ba9 Files: 768 Total size: 126.9 MB Directory structure: gitextract_grlqwy99/ ├── .gitignore ├── 3drecon/ │ ├── configs/ │ │ └── neus_36.yaml │ ├── ours_GSO_T1/ │ │ └── NeuS/ │ │ ├── grandmother/ │ │ │ ├── mesh.ply │ │ │ └── tensorboard_logs/ │ │ │ └── version_0/ │ │ │ ├── events.out.tfevents.1706660493.xin-PC.934205.0 │ │ │ └── hparams.yaml │ │ ├── lion/ │ │ │ ├── mesh.ply │ │ │ └── tensorboard_logs/ │ │ │ └── version_0/ │ │ │ ├── events.out.tfevents.1706659999.xin-PC.931603.0 │ │ │ └── hparams.yaml │ │ └── train/ │ │ └── tensorboard_logs/ │ │ └── version_0/ │ │ ├── events.out.tfevents.1706660976.xin-PC.936418.0 │ │ └── hparams.yaml │ ├── raymarching/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── raymarching.py │ │ ├── setup.py │ │ └── src/ │ │ ├── bindings.cpp │ │ ├── raymarching.cu │ │ └── raymarching.h │ ├── renderer/ │ │ ├── agg_net.py │ │ ├── cost_reg_net.py │ │ ├── dummy_dataset.py │ │ ├── feature_net.py │ │ ├── neus_networks.py │ │ ├── ngp_renderer.py │ │ └── renderer.py │ ├── run_NeuS.py │ ├── train_renderer.py │ └── util.py ├── 4DoF/ │ ├── CN_encoder.py │ ├── dataset.py │ ├── diffusers/ │ │ ├── __init__.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ ├── diffusers_cli.py │ │ │ └── env.py │ │ ├── configuration_utils.py │ │ ├── dependency_versions_check.py │ │ ├── dependency_versions_table.py │ │ ├── experimental/ │ │ │ ├── __init__.py │ │ │ └── rl/ │ │ │ ├── __init__.py │ │ │ └── value_guided_sampling.py │ │ ├── image_processor.py │ │ ├── loaders.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── attention.py │ │ │ ├── attention_flax.py │ │ │ ├── attention_processor.py │ │ │ ├── autoencoder_kl.py │ │ │ ├── controlnet.py │ │ │ ├── controlnet_flax.py │ │ │ ├── cross_attention.py │ │ │ ├── dual_transformer_2d.py │ │ │ ├── embeddings.py │ │ │ ├── embeddings_flax.py │ │ │ ├── modeling_flax_pytorch_utils.py │ │ │ ├── modeling_flax_utils.py │ │ │ ├── 
modeling_pytorch_flax_utils.py │ │ │ ├── modeling_utils.py │ │ │ ├── prior_transformer.py │ │ │ ├── resnet.py │ │ │ ├── resnet_flax.py │ │ │ ├── t5_film_transformer.py │ │ │ ├── transformer_2d.py │ │ │ ├── transformer_temporal.py │ │ │ ├── unet_1d.py │ │ │ ├── unet_1d_blocks.py │ │ │ ├── unet_2d.py │ │ │ ├── unet_2d_blocks.py │ │ │ ├── unet_2d_blocks_flax.py │ │ │ ├── unet_2d_condition.py │ │ │ ├── unet_2d_condition_flax.py │ │ │ ├── unet_3d_blocks.py │ │ │ ├── unet_3d_condition.py │ │ │ ├── vae.py │ │ │ ├── vae_flax.py │ │ │ └── vq_model.py │ │ ├── optimization.py │ │ ├── pipeline_utils.py │ │ ├── pipelines/ │ │ │ ├── __init__.py │ │ │ ├── alt_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_roberta_series.py │ │ │ │ ├── pipeline_alt_diffusion.py │ │ │ │ └── pipeline_alt_diffusion_img2img.py │ │ │ ├── audio_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── mel.py │ │ │ │ └── pipeline_audio_diffusion.py │ │ │ ├── audioldm/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_audioldm.py │ │ │ ├── consistency_models/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_consistency_models.py │ │ │ ├── controlnet/ │ │ │ │ ├── __init__.py │ │ │ │ ├── multicontrolnet.py │ │ │ │ ├── pipeline_controlnet.py │ │ │ │ ├── pipeline_controlnet_img2img.py │ │ │ │ ├── pipeline_controlnet_inpaint.py │ │ │ │ └── pipeline_flax_controlnet.py │ │ │ ├── dance_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_dance_diffusion.py │ │ │ ├── ddim/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddim.py │ │ │ ├── ddpm/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddpm.py │ │ │ ├── deepfloyd_if/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_if.py │ │ │ │ ├── pipeline_if_img2img.py │ │ │ │ ├── pipeline_if_img2img_superresolution.py │ │ │ │ ├── pipeline_if_inpainting.py │ │ │ │ ├── pipeline_if_inpainting_superresolution.py │ │ │ │ ├── pipeline_if_superresolution.py │ │ │ │ ├── safety_checker.py │ │ │ │ ├── timesteps.py │ │ │ │ └── watermark.py │ │ │ ├── dit/ │ │ │ │ ├── __init__.py │ │ │ │ └── 
pipeline_dit.py │ │ │ ├── kandinsky/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_kandinsky.py │ │ │ │ ├── pipeline_kandinsky_img2img.py │ │ │ │ ├── pipeline_kandinsky_inpaint.py │ │ │ │ ├── pipeline_kandinsky_prior.py │ │ │ │ └── text_encoder.py │ │ │ ├── kandinsky2_2/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_kandinsky2_2.py │ │ │ │ ├── pipeline_kandinsky2_2_controlnet.py │ │ │ │ ├── pipeline_kandinsky2_2_controlnet_img2img.py │ │ │ │ ├── pipeline_kandinsky2_2_img2img.py │ │ │ │ ├── pipeline_kandinsky2_2_inpainting.py │ │ │ │ ├── pipeline_kandinsky2_2_prior.py │ │ │ │ └── pipeline_kandinsky2_2_prior_emb2emb.py │ │ │ ├── latent_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_latent_diffusion.py │ │ │ │ └── pipeline_latent_diffusion_superresolution.py │ │ │ ├── latent_diffusion_uncond/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_latent_diffusion_uncond.py │ │ │ ├── onnx_utils.py │ │ │ ├── paint_by_example/ │ │ │ │ ├── __init__.py │ │ │ │ ├── image_encoder.py │ │ │ │ └── pipeline_paint_by_example.py │ │ │ ├── pipeline_flax_utils.py │ │ │ ├── pipeline_utils.py │ │ │ ├── pndm/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_pndm.py │ │ │ ├── repaint/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_repaint.py │ │ │ ├── score_sde_ve/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_score_sde_ve.py │ │ │ ├── semantic_stable_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_semantic_stable_diffusion.py │ │ │ ├── shap_e/ │ │ │ │ ├── __init__.py │ │ │ │ ├── camera.py │ │ │ │ ├── pipeline_shap_e.py │ │ │ │ ├── pipeline_shap_e_img2img.py │ │ │ │ └── renderer.py │ │ │ ├── spectrogram_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── continous_encoder.py │ │ │ │ ├── midi_utils.py │ │ │ │ ├── notes_encoder.py │ │ │ │ └── pipeline_spectrogram_diffusion.py │ │ │ ├── stable_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── convert_from_ckpt.py │ │ │ │ ├── pipeline_cycle_diffusion.py │ │ │ │ ├── pipeline_flax_stable_diffusion.py │ │ │ │ ├── 
pipeline_flax_stable_diffusion_controlnet.py │ │ │ │ ├── pipeline_flax_stable_diffusion_img2img.py │ │ │ │ ├── pipeline_flax_stable_diffusion_inpaint.py │ │ │ │ ├── pipeline_onnx_stable_diffusion.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_img2img.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_inpaint.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_upscale.py │ │ │ │ ├── pipeline_stable_diffusion.py │ │ │ │ ├── pipeline_stable_diffusion_attend_and_excite.py │ │ │ │ ├── pipeline_stable_diffusion_controlnet.py │ │ │ │ ├── pipeline_stable_diffusion_depth2img.py │ │ │ │ ├── pipeline_stable_diffusion_diffedit.py │ │ │ │ ├── pipeline_stable_diffusion_image_variation.py │ │ │ │ ├── pipeline_stable_diffusion_img2img.py │ │ │ │ ├── pipeline_stable_diffusion_inpaint.py │ │ │ │ ├── pipeline_stable_diffusion_inpaint_legacy.py │ │ │ │ ├── pipeline_stable_diffusion_instruct_pix2pix.py │ │ │ │ ├── pipeline_stable_diffusion_k_diffusion.py │ │ │ │ ├── pipeline_stable_diffusion_latent_upscale.py │ │ │ │ ├── pipeline_stable_diffusion_ldm3d.py │ │ │ │ ├── pipeline_stable_diffusion_model_editing.py │ │ │ │ ├── pipeline_stable_diffusion_panorama.py │ │ │ │ ├── pipeline_stable_diffusion_paradigms.py │ │ │ │ ├── pipeline_stable_diffusion_pix2pix_zero.py │ │ │ │ ├── pipeline_stable_diffusion_sag.py │ │ │ │ ├── pipeline_stable_diffusion_upscale.py │ │ │ │ ├── pipeline_stable_unclip.py │ │ │ │ ├── pipeline_stable_unclip_img2img.py │ │ │ │ ├── safety_checker.py │ │ │ │ ├── safety_checker_flax.py │ │ │ │ └── stable_unclip_image_normalizer.py │ │ │ ├── stable_diffusion_safe/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_stable_diffusion_safe.py │ │ │ │ └── safety_checker.py │ │ │ ├── stable_diffusion_xl/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_stable_diffusion_xl.py │ │ │ │ ├── pipeline_stable_diffusion_xl_img2img.py │ │ │ │ └── watermark.py │ │ │ ├── stochastic_karras_ve/ │ │ │ │ ├── __init__.py │ │ │ │ └── 
pipeline_stochastic_karras_ve.py │ │ │ ├── text_to_video_synthesis/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_text_to_video_synth.py │ │ │ │ ├── pipeline_text_to_video_synth_img2img.py │ │ │ │ └── pipeline_text_to_video_zero.py │ │ │ ├── unclip/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_unclip.py │ │ │ │ ├── pipeline_unclip_image_variation.py │ │ │ │ └── text_proj.py │ │ │ ├── unidiffuser/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_text_decoder.py │ │ │ │ ├── modeling_uvit.py │ │ │ │ └── pipeline_unidiffuser.py │ │ │ ├── versatile_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_text_unet.py │ │ │ │ ├── pipeline_versatile_diffusion.py │ │ │ │ ├── pipeline_versatile_diffusion_dual_guided.py │ │ │ │ ├── pipeline_versatile_diffusion_image_variation.py │ │ │ │ └── pipeline_versatile_diffusion_text_to_image.py │ │ │ └── vq_diffusion/ │ │ │ ├── __init__.py │ │ │ └── pipeline_vq_diffusion.py │ │ ├── schedulers/ │ │ │ ├── __init__.py │ │ │ ├── scheduling_consistency_models.py │ │ │ ├── scheduling_ddim.py │ │ │ ├── scheduling_ddim_flax.py │ │ │ ├── scheduling_ddim_inverse.py │ │ │ ├── scheduling_ddim_parallel.py │ │ │ ├── scheduling_ddpm.py │ │ │ ├── scheduling_ddpm_flax.py │ │ │ ├── scheduling_ddpm_parallel.py │ │ │ ├── scheduling_deis_multistep.py │ │ │ ├── scheduling_dpmsolver_multistep.py │ │ │ ├── scheduling_dpmsolver_multistep_flax.py │ │ │ ├── scheduling_dpmsolver_multistep_inverse.py │ │ │ ├── scheduling_dpmsolver_sde.py │ │ │ ├── scheduling_dpmsolver_singlestep.py │ │ │ ├── scheduling_euler_ancestral_discrete.py │ │ │ ├── scheduling_euler_discrete.py │ │ │ ├── scheduling_heun_discrete.py │ │ │ ├── scheduling_ipndm.py │ │ │ ├── scheduling_k_dpm_2_ancestral_discrete.py │ │ │ ├── scheduling_k_dpm_2_discrete.py │ │ │ ├── scheduling_karras_ve.py │ │ │ ├── scheduling_karras_ve_flax.py │ │ │ ├── scheduling_lms_discrete.py │ │ │ ├── scheduling_lms_discrete_flax.py │ │ │ ├── scheduling_pndm.py │ │ │ ├── scheduling_pndm_flax.py │ │ │ ├── scheduling_repaint.py 
│ │ │ ├── scheduling_sde_ve.py │ │ │ ├── scheduling_sde_ve_flax.py │ │ │ ├── scheduling_sde_vp.py │ │ │ ├── scheduling_unclip.py │ │ │ ├── scheduling_unipc_multistep.py │ │ │ ├── scheduling_utils.py │ │ │ ├── scheduling_utils_flax.py │ │ │ └── scheduling_vq_diffusion.py │ │ ├── training_utils.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── accelerate_utils.py │ │ ├── constants.py │ │ ├── deprecation_utils.py │ │ ├── doc_utils.py │ │ ├── dummy_flax_and_transformers_objects.py │ │ ├── dummy_flax_objects.py │ │ ├── dummy_note_seq_objects.py │ │ ├── dummy_onnx_objects.py │ │ ├── dummy_pt_objects.py │ │ ├── dummy_torch_and_librosa_objects.py │ │ ├── dummy_torch_and_scipy_objects.py │ │ ├── dummy_torch_and_torchsde_objects.py │ │ ├── dummy_torch_and_transformers_and_invisible_watermark_objects.py │ │ ├── dummy_torch_and_transformers_and_k_diffusion_objects.py │ │ ├── dummy_torch_and_transformers_and_onnx_objects.py │ │ ├── dummy_torch_and_transformers_objects.py │ │ ├── dummy_transformers_and_torch_and_note_seq_objects.py │ │ ├── dynamic_modules_utils.py │ │ ├── hub_utils.py │ │ ├── import_utils.py │ │ ├── logging.py │ │ ├── model_card_template.md │ │ ├── outputs.py │ │ ├── pil_utils.py │ │ ├── testing_utils.py │ │ └── torch_utils.py │ ├── pipeline_zero1to3.py │ ├── train_eschernet.py │ └── unet_2d_condition.py ├── 6DoF/ │ ├── CN_encoder.py │ ├── dataset.py │ ├── diffusers/ │ │ ├── __init__.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ ├── diffusers_cli.py │ │ │ └── env.py │ │ ├── configuration_utils.py │ │ ├── dependency_versions_check.py │ │ ├── dependency_versions_table.py │ │ ├── experimental/ │ │ │ ├── __init__.py │ │ │ └── rl/ │ │ │ ├── __init__.py │ │ │ └── value_guided_sampling.py │ │ ├── image_processor.py │ │ ├── loaders.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── attention.py │ │ │ ├── attention_flax.py │ │ │ ├── attention_processor.py │ │ │ ├── autoencoder_kl.py │ │ │ ├── controlnet.py │ │ │ ├── controlnet_flax.py │ │ │ ├── 
cross_attention.py │ │ │ ├── dual_transformer_2d.py │ │ │ ├── embeddings.py │ │ │ ├── embeddings_flax.py │ │ │ ├── modeling_flax_pytorch_utils.py │ │ │ ├── modeling_flax_utils.py │ │ │ ├── modeling_pytorch_flax_utils.py │ │ │ ├── modeling_utils.py │ │ │ ├── prior_transformer.py │ │ │ ├── resnet.py │ │ │ ├── resnet_flax.py │ │ │ ├── t5_film_transformer.py │ │ │ ├── transformer_2d.py │ │ │ ├── transformer_temporal.py │ │ │ ├── unet_1d.py │ │ │ ├── unet_1d_blocks.py │ │ │ ├── unet_2d.py │ │ │ ├── unet_2d_blocks.py │ │ │ ├── unet_2d_blocks_flax.py │ │ │ ├── unet_2d_condition.py │ │ │ ├── unet_2d_condition_flax.py │ │ │ ├── unet_3d_blocks.py │ │ │ ├── unet_3d_condition.py │ │ │ ├── vae.py │ │ │ ├── vae_flax.py │ │ │ └── vq_model.py │ │ ├── optimization.py │ │ ├── pipeline_utils.py │ │ ├── pipelines/ │ │ │ ├── __init__.py │ │ │ ├── alt_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_roberta_series.py │ │ │ │ ├── pipeline_alt_diffusion.py │ │ │ │ └── pipeline_alt_diffusion_img2img.py │ │ │ ├── audio_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── mel.py │ │ │ │ └── pipeline_audio_diffusion.py │ │ │ ├── audioldm/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_audioldm.py │ │ │ ├── consistency_models/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_consistency_models.py │ │ │ ├── controlnet/ │ │ │ │ ├── __init__.py │ │ │ │ ├── multicontrolnet.py │ │ │ │ ├── pipeline_controlnet.py │ │ │ │ ├── pipeline_controlnet_img2img.py │ │ │ │ ├── pipeline_controlnet_inpaint.py │ │ │ │ └── pipeline_flax_controlnet.py │ │ │ ├── dance_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_dance_diffusion.py │ │ │ ├── ddim/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddim.py │ │ │ ├── ddpm/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_ddpm.py │ │ │ ├── deepfloyd_if/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_if.py │ │ │ │ ├── pipeline_if_img2img.py │ │ │ │ ├── pipeline_if_img2img_superresolution.py │ │ │ │ ├── pipeline_if_inpainting.py │ │ │ │ ├── 
pipeline_if_inpainting_superresolution.py │ │ │ │ ├── pipeline_if_superresolution.py │ │ │ │ ├── safety_checker.py │ │ │ │ ├── timesteps.py │ │ │ │ └── watermark.py │ │ │ ├── dit/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_dit.py │ │ │ ├── kandinsky/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_kandinsky.py │ │ │ │ ├── pipeline_kandinsky_img2img.py │ │ │ │ ├── pipeline_kandinsky_inpaint.py │ │ │ │ ├── pipeline_kandinsky_prior.py │ │ │ │ └── text_encoder.py │ │ │ ├── kandinsky2_2/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_kandinsky2_2.py │ │ │ │ ├── pipeline_kandinsky2_2_controlnet.py │ │ │ │ ├── pipeline_kandinsky2_2_controlnet_img2img.py │ │ │ │ ├── pipeline_kandinsky2_2_img2img.py │ │ │ │ ├── pipeline_kandinsky2_2_inpainting.py │ │ │ │ ├── pipeline_kandinsky2_2_prior.py │ │ │ │ └── pipeline_kandinsky2_2_prior_emb2emb.py │ │ │ ├── latent_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_latent_diffusion.py │ │ │ │ └── pipeline_latent_diffusion_superresolution.py │ │ │ ├── latent_diffusion_uncond/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_latent_diffusion_uncond.py │ │ │ ├── onnx_utils.py │ │ │ ├── paint_by_example/ │ │ │ │ ├── __init__.py │ │ │ │ ├── image_encoder.py │ │ │ │ └── pipeline_paint_by_example.py │ │ │ ├── pipeline_flax_utils.py │ │ │ ├── pipeline_utils.py │ │ │ ├── pndm/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_pndm.py │ │ │ ├── repaint/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_repaint.py │ │ │ ├── score_sde_ve/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_score_sde_ve.py │ │ │ ├── semantic_stable_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_semantic_stable_diffusion.py │ │ │ ├── shap_e/ │ │ │ │ ├── __init__.py │ │ │ │ ├── camera.py │ │ │ │ ├── pipeline_shap_e.py │ │ │ │ ├── pipeline_shap_e_img2img.py │ │ │ │ └── renderer.py │ │ │ ├── spectrogram_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── continous_encoder.py │ │ │ │ ├── midi_utils.py │ │ │ │ ├── notes_encoder.py │ │ │ │ └── pipeline_spectrogram_diffusion.py │ │ │ 
├── stable_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── convert_from_ckpt.py │ │ │ │ ├── pipeline_cycle_diffusion.py │ │ │ │ ├── pipeline_flax_stable_diffusion.py │ │ │ │ ├── pipeline_flax_stable_diffusion_controlnet.py │ │ │ │ ├── pipeline_flax_stable_diffusion_img2img.py │ │ │ │ ├── pipeline_flax_stable_diffusion_inpaint.py │ │ │ │ ├── pipeline_onnx_stable_diffusion.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_img2img.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_inpaint.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_inpaint_legacy.py │ │ │ │ ├── pipeline_onnx_stable_diffusion_upscale.py │ │ │ │ ├── pipeline_stable_diffusion.py │ │ │ │ ├── pipeline_stable_diffusion_attend_and_excite.py │ │ │ │ ├── pipeline_stable_diffusion_controlnet.py │ │ │ │ ├── pipeline_stable_diffusion_depth2img.py │ │ │ │ ├── pipeline_stable_diffusion_diffedit.py │ │ │ │ ├── pipeline_stable_diffusion_image_variation.py │ │ │ │ ├── pipeline_stable_diffusion_img2img.py │ │ │ │ ├── pipeline_stable_diffusion_inpaint.py │ │ │ │ ├── pipeline_stable_diffusion_inpaint_legacy.py │ │ │ │ ├── pipeline_stable_diffusion_instruct_pix2pix.py │ │ │ │ ├── pipeline_stable_diffusion_k_diffusion.py │ │ │ │ ├── pipeline_stable_diffusion_latent_upscale.py │ │ │ │ ├── pipeline_stable_diffusion_ldm3d.py │ │ │ │ ├── pipeline_stable_diffusion_model_editing.py │ │ │ │ ├── pipeline_stable_diffusion_panorama.py │ │ │ │ ├── pipeline_stable_diffusion_paradigms.py │ │ │ │ ├── pipeline_stable_diffusion_pix2pix_zero.py │ │ │ │ ├── pipeline_stable_diffusion_sag.py │ │ │ │ ├── pipeline_stable_diffusion_upscale.py │ │ │ │ ├── pipeline_stable_unclip.py │ │ │ │ ├── pipeline_stable_unclip_img2img.py │ │ │ │ ├── safety_checker.py │ │ │ │ ├── safety_checker_flax.py │ │ │ │ └── stable_unclip_image_normalizer.py │ │ │ ├── stable_diffusion_safe/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_stable_diffusion_safe.py │ │ │ │ └── safety_checker.py │ │ │ ├── stable_diffusion_xl/ │ │ │ │ ├── __init__.py │ │ │ │ ├── 
pipeline_stable_diffusion_xl.py │ │ │ │ ├── pipeline_stable_diffusion_xl_img2img.py │ │ │ │ └── watermark.py │ │ │ ├── stochastic_karras_ve/ │ │ │ │ ├── __init__.py │ │ │ │ └── pipeline_stochastic_karras_ve.py │ │ │ ├── text_to_video_synthesis/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_text_to_video_synth.py │ │ │ │ ├── pipeline_text_to_video_synth_img2img.py │ │ │ │ └── pipeline_text_to_video_zero.py │ │ │ ├── unclip/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pipeline_unclip.py │ │ │ │ ├── pipeline_unclip_image_variation.py │ │ │ │ └── text_proj.py │ │ │ ├── unidiffuser/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_text_decoder.py │ │ │ │ ├── modeling_uvit.py │ │ │ │ └── pipeline_unidiffuser.py │ │ │ ├── versatile_diffusion/ │ │ │ │ ├── __init__.py │ │ │ │ ├── modeling_text_unet.py │ │ │ │ ├── pipeline_versatile_diffusion.py │ │ │ │ ├── pipeline_versatile_diffusion_dual_guided.py │ │ │ │ ├── pipeline_versatile_diffusion_image_variation.py │ │ │ │ └── pipeline_versatile_diffusion_text_to_image.py │ │ │ └── vq_diffusion/ │ │ │ ├── __init__.py │ │ │ └── pipeline_vq_diffusion.py │ │ ├── schedulers/ │ │ │ ├── __init__.py │ │ │ ├── scheduling_consistency_models.py │ │ │ ├── scheduling_ddim.py │ │ │ ├── scheduling_ddim_flax.py │ │ │ ├── scheduling_ddim_inverse.py │ │ │ ├── scheduling_ddim_parallel.py │ │ │ ├── scheduling_ddpm.py │ │ │ ├── scheduling_ddpm_flax.py │ │ │ ├── scheduling_ddpm_parallel.py │ │ │ ├── scheduling_deis_multistep.py │ │ │ ├── scheduling_dpmsolver_multistep.py │ │ │ ├── scheduling_dpmsolver_multistep_flax.py │ │ │ ├── scheduling_dpmsolver_multistep_inverse.py │ │ │ ├── scheduling_dpmsolver_sde.py │ │ │ ├── scheduling_dpmsolver_singlestep.py │ │ │ ├── scheduling_euler_ancestral_discrete.py │ │ │ ├── scheduling_euler_discrete.py │ │ │ ├── scheduling_heun_discrete.py │ │ │ ├── scheduling_ipndm.py │ │ │ ├── scheduling_k_dpm_2_ancestral_discrete.py │ │ │ ├── scheduling_k_dpm_2_discrete.py │ │ │ ├── scheduling_karras_ve.py │ │ │ ├── 
scheduling_karras_ve_flax.py │ │ │ ├── scheduling_lms_discrete.py │ │ │ ├── scheduling_lms_discrete_flax.py │ │ │ ├── scheduling_pndm.py │ │ │ ├── scheduling_pndm_flax.py │ │ │ ├── scheduling_repaint.py │ │ │ ├── scheduling_sde_ve.py │ │ │ ├── scheduling_sde_ve_flax.py │ │ │ ├── scheduling_sde_vp.py │ │ │ ├── scheduling_unclip.py │ │ │ ├── scheduling_unipc_multistep.py │ │ │ ├── scheduling_utils.py │ │ │ ├── scheduling_utils_flax.py │ │ │ └── scheduling_vq_diffusion.py │ │ ├── training_utils.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── accelerate_utils.py │ │ ├── constants.py │ │ ├── deprecation_utils.py │ │ ├── doc_utils.py │ │ ├── dummy_flax_and_transformers_objects.py │ │ ├── dummy_flax_objects.py │ │ ├── dummy_note_seq_objects.py │ │ ├── dummy_onnx_objects.py │ │ ├── dummy_pt_objects.py │ │ ├── dummy_torch_and_librosa_objects.py │ │ ├── dummy_torch_and_scipy_objects.py │ │ ├── dummy_torch_and_torchsde_objects.py │ │ ├── dummy_torch_and_transformers_and_invisible_watermark_objects.py │ │ ├── dummy_torch_and_transformers_and_k_diffusion_objects.py │ │ ├── dummy_torch_and_transformers_and_onnx_objects.py │ │ ├── dummy_torch_and_transformers_objects.py │ │ ├── dummy_transformers_and_torch_and_note_seq_objects.py │ │ ├── dynamic_modules_utils.py │ │ ├── hub_utils.py │ │ ├── import_utils.py │ │ ├── logging.py │ │ ├── model_card_template.md │ │ ├── outputs.py │ │ ├── pil_utils.py │ │ ├── testing_utils.py │ │ └── torch_utils.py │ ├── instructions.md │ ├── pipeline_zero1to3.py │ ├── train_eschernet.py │ └── unet_2d_condition.py ├── CaPE.py ├── LICENSE ├── README.md ├── demo/ │ └── GSO30/ │ └── elephant/ │ └── render_mvs_25/ │ └── model/ │ ├── 000.npy │ ├── 001.npy │ ├── 002.npy │ ├── 003.npy │ ├── 004.npy │ ├── 005.npy │ ├── 006.npy │ ├── 007.npy │ ├── 008.npy │ ├── 009.npy │ ├── 010.npy │ ├── 011.npy │ ├── 012.npy │ ├── 013.npy │ ├── 014.npy │ ├── 015.npy │ ├── 016.npy │ ├── 017.npy │ ├── 018.npy │ ├── 019.npy │ ├── 020.npy │ ├── 021.npy │ ├── 022.npy │ ├── 023.npy │ 
└── 024.npy ├── environment.yml ├── eval_eschernet.py ├── eval_eschernet.sh ├── metrics/ │ ├── NeRF_idx/ │ │ ├── chair/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ ├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ ├── drums/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ ├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ ├── ficus/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ 
├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ ├── hotdog/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ ├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ ├── lego/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ ├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ ├── materials/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── 
train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ ├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ ├── mic/ │ │ │ ├── test_M20.npy │ │ │ ├── train_N100M20_random.npy │ │ │ ├── train_N10M20_random.npy │ │ │ ├── train_N1M20_random.npy │ │ │ ├── train_N20M20_random.npy │ │ │ ├── train_N2M20_random.npy │ │ │ ├── train_N30M20_random.npy │ │ │ ├── train_N3M20_random.npy │ │ │ ├── train_N50M20_random.npy │ │ │ ├── train_N5M20_random.npy │ │ │ ├── train_fov0M20_fov.npy │ │ │ ├── train_fov135M20_fov.npy │ │ │ ├── train_fov15M20_fov.npy │ │ │ ├── train_fov180M20_fov.npy │ │ │ ├── train_fov225M20_fov.npy │ │ │ ├── train_fov270M20_fov.npy │ │ │ ├── train_fov30M20_fov.npy │ │ │ ├── train_fov360M20_fov.npy │ │ │ ├── train_fov50M20_fov.npy │ │ │ ├── train_fov5M20_fov.npy │ │ │ └── train_fov90M20_fov.npy │ │ └── ship/ │ │ ├── test_M20.npy │ │ ├── train_N100M20_random.npy │ │ ├── train_N10M20_random.npy │ │ ├── train_N1M20_random.npy │ │ ├── train_N20M20_random.npy │ │ ├── train_N2M20_random.npy │ │ ├── train_N30M20_random.npy │ │ ├── train_N3M20_random.npy │ │ ├── train_N50M20_random.npy │ │ ├── train_N5M20_random.npy │ │ ├── train_fov0M20_fov.npy │ │ ├── train_fov135M20_fov.npy │ │ ├── train_fov15M20_fov.npy │ │ ├── train_fov180M20_fov.npy │ │ ├── train_fov225M20_fov.npy │ │ ├── train_fov270M20_fov.npy │ │ ├── train_fov30M20_fov.npy │ │ ├── train_fov360M20_fov.npy │ │ ├── train_fov50M20_fov.npy │ │ ├── train_fov5M20_fov.npy │ │ └── train_fov90M20_fov.npy │ ├── RenderOption_phong.json │ ├── RenderOption_rgb.json │ ├── ScreenCamera_0.json │ ├── ScreenCamera_1.json │ ├── ScreenCamera_2.json │ ├── ScreenCamera_3.json │ ├── ScreenCamera_4.json │ ├── ScreenCamera_5.json │ ├── eval_2D_NVS.py │ ├── eval_3D_GSO.py │ └── metrics.py 
└── scripts/ ├── README.md ├── all_invalid.npy ├── blender_script_mvs.py ├── empty_ids.npy ├── invalid_ids.npy ├── obj_ids.npy ├── objaverse_filter.py └── render_all_mvs.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea/ training/ lightning_logs/ image_log/ *.pth *.pt *.ckpt *.safetensors gradio_pose2image_private.py gradio_canny2image_private.py # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # log files log_bigboy logs* wandb # pretrained models pretrained ================================================ FILE: 3drecon/configs/neus_36.yaml ================================================ model: base_lr: 5.0e-4 target: renderer.renderer.RendererTrainer params: total_steps: 2000 warm_up_steps: 100 train_batch_num: 2560 train_batch_fg_num: 512 test_batch_num: 4096 use_mask: true lambda_rgb_loss: 0.5 lambda_mask_loss: 1.0 lambda_eikonal_loss: 0.1 use_warm_up: true data: target: renderer.dummy_dataset.DummyDataset params: {} callbacks: save_interval: 500 trainer: val_check_interval: 500 max_steps: 2000 ================================================ FILE: 3drecon/ours_GSO_T1/NeuS/grandmother/mesh.ply ================================================ [File too large to display: 10.4 MB] ================================================ FILE: 3drecon/ours_GSO_T1/NeuS/grandmother/tensorboard_logs/version_0/hparams.yaml ================================================ {} ================================================ FILE: 3drecon/ours_GSO_T1/NeuS/lion/mesh.ply ================================================ [File too large to display: 10.3 MB] ================================================ FILE: 3drecon/ours_GSO_T1/NeuS/lion/tensorboard_logs/version_0/hparams.yaml ================================================ {} ================================================ FILE: 3drecon/ours_GSO_T1/NeuS/train/tensorboard_logs/version_0/hparams.yaml ================================================ {} ================================================ FILE: 3drecon/raymarching/__init__.py 
================================================ from .raymarching import * ================================================ FILE: 3drecon/raymarching/backend.py ================================================ import os from torch.utils.cpp_extension import load _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ '-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', ] if os.name == "posix": c_flags = ['-O3', '-std=c++14'] elif os.name == "nt": c_flags = ['/O2', '/std:c++17'] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") os.environ["PATH"] += ";" + cl_path _backend = load(name='_raymarching', extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, sources=[os.path.join(_src_path, 'src', f) for f in [ 'raymarching.cu', 'bindings.cpp', ]], ) __all__ = ['_backend'] ================================================ FILE: 3drecon/raymarching/raymarching.py ================================================ import numpy as np import time import torch import torch.nn as nn from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd try: import _raymarching as _backend except ImportError: from .backend import _backend # ---------------------------------------- # utils # ---------------------------------------- class _near_far_from_aabb(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): ''' near_far_from_aabb, CUDA implementation Calculate rays' 
intersection time (near and far) with aabb Args: rays_o: float, [N, 3] rays_d: float, [N, 3] aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) min_near: float, scalar Returns: nears: float, [N] fars: float, [N] ''' if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # num rays nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) return nears, fars near_far_from_aabb = _near_far_from_aabb.apply class _sph_from_ray(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, radius): ''' sph_from_ray, CUDA implementation get spherical coordinate on the background sphere from rays. Assume rays_o are inside the Sphere(radius). Args: rays_o: [N, 3] rays_d: [N, 3] radius: scalar, float Return: coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) ''' if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # num rays coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) return coords sph_from_ray = _sph_from_ray.apply class _morton3D(Function): @staticmethod def forward(ctx, coords): ''' morton3D, CUDA implementation Args: coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) TODO: check if the coord range is valid! 
(current 128 is safe) Returns: indices: [N], int32, in [0, 128^3) ''' if not coords.is_cuda: coords = coords.cuda() N = coords.shape[0] indices = torch.empty(N, dtype=torch.int32, device=coords.device) _backend.morton3D(coords.int(), N, indices) return indices morton3D = _morton3D.apply class _morton3D_invert(Function): @staticmethod def forward(ctx, indices): ''' morton3D_invert, CUDA implementation Args: indices: [N], int32, in [0, 128^3) Returns: coords: [N, 3], int32, in [0, 128) ''' if not indices.is_cuda: indices = indices.cuda() N = indices.shape[0] coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) _backend.morton3D_invert(indices.int(), N, coords) return coords morton3D_invert = _morton3D_invert.apply class _packbits(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, grid, thresh, bitfield=None): ''' packbits, CUDA implementation Pack up the density grid into a bit field to accelerate ray marching. Args: grid: float, [C, H * H * H], assume H % 2 == 0 thresh: float, threshold Returns: bitfield: uint8, [C, H * H * H / 8] ''' if not grid.is_cuda: grid = grid.cuda() grid = grid.contiguous() C = grid.shape[0] H3 = grid.shape[1] N = C * H3 // 8 if bitfield is None: bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) _backend.packbits(grid, N, thresh, bitfield) return bitfield packbits = _packbits.apply # ---------------------------------------- # train functions # ---------------------------------------- class _march_rays_train(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): ''' march rays to generate points (forward only) Args: rays_o/d: float, [N, 3] bound: float, scalar density_bitfield: uint8: [CHHH // 8] C: int H: int nears/fars: float, [N] step_counter: int32, (2), used to count the actual number 
of generated points. mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) perturb: bool align: int, pad output so its size is dividable by align, set to -1 to disable. force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) max_steps: int, max number of sampled points along each ray, also affect min_stepsize. Returns: xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) dirs: float, [M, 3], all generated points' view dirs. deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth) rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] ''' if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) density_bitfield = density_bitfield.contiguous() N = rays_o.shape[0] # num rays M = N * max_steps # init max points number in total # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. 
if not force_all_rays and mean_count > 0: if align > 0: mean_count += align - mean_count % align M = mean_count xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps if step_counter is None: step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter if perturb: noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) else: noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number #print(step_counter, M) # only used at the first (few) epochs. if force_all_rays or mean_count <= 0: m = step_counter[0].item() # D2H copy if align > 0: m += align - m % align xyzs = xyzs[:m] dirs = dirs[:m] deltas = deltas[:m] torch.cuda.empty_cache() return xyzs, dirs, deltas, rays march_rays_train = _march_rays_train.apply class _composite_rays_train(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4): ''' composite rays' rgbs, according to the ray marching formula. Args: rgbs: float, [M, 3] sigmas: float, [M,] deltas: float, [M, 2] rays: int32, [N, 3] Returns: weights_sum: float, [N,], the alpha channel depth: float, [N, ], the Depth image: float, [N, 3], the RGB channel (after multiplying alpha!) 
''' sigmas = sigmas.contiguous() rgbs = rgbs.contiguous() M = sigmas.shape[0] N = rays.shape[0] weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) _backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image) ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) ctx.dims = [M, N, T_thresh] return weights_sum, depth, image @staticmethod @custom_bwd def backward(ctx, grad_weights_sum, grad_depth, grad_image): # NOTE: grad_depth is not used now! It won't be propagated to sigmas. grad_weights_sum = grad_weights_sum.contiguous() grad_image = grad_image.contiguous() sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors M, N, T_thresh = ctx.dims grad_sigmas = torch.zeros_like(sigmas) grad_rgbs = torch.zeros_like(rgbs) _backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs) return grad_sigmas, grad_rgbs, None, None, None composite_rays_train = _composite_rays_train.apply # ---------------------------------------- # infer functions # ---------------------------------------- class _march_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): ''' march rays to generate points (forward only, for inference) Args: n_alive: int, number of alive rays n_step: int, how many steps we march rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) rays_t: float, [N], the alive rays' time, we only use the first n_alive. 
rays_o/d: float, [N, 3] bound: float, scalar density_bitfield: uint8: [CHHH // 8] C: int H: int nears/fars: float, [N] align: int, pad output so its size is dividable by align, set to -1 to disable. perturb: bool/int, int > 0 is used as the random seed. dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) max_steps: int, max number of sampled points along each ray, also affect min_stepsize. Returns: xyzs: float, [n_alive * n_step, 3], all generated points' coords dirs: float, [n_alive * n_step, 3], all generated points' view dirs. deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). ''' if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) M = n_alive * n_step if align > 0: M += align - (M % align) xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth if perturb: # torch.manual_seed(perturb) # test_gui uses spp index as seed noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) else: noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) return xyzs, dirs, deltas march_rays = _march_rays.apply class _composite_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): ''' composite rays' rgbs, according to the 
ray marching formula. (for inference) Args: n_alive: int, number of alive rays n_step: int, how many steps we march rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) rays_t: float, [N], the alive rays' time sigmas: float, [n_alive * n_step,] rgbs: float, [n_alive * n_step, 3] deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). In-place Outputs: weights_sum: float, [N,], the alpha channel depth: float, [N,], the depth value image: float, [N, 3], the RGB channel (after multiplying alpha!) ''' _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) return tuple() composite_rays = _composite_rays.apply ================================================ FILE: 3drecon/raymarching/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ '-O3', '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', ] if os.name == "posix": c_flags = ['-O3', '-std=c++14'] elif os.name == "nt": c_flags = ['/O2', '/std:c++17'] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") os.environ["PATH"] += ";" + cl_path ''' Usage: python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) python setup.py install # build extensions and install (copy) to PATH. pip install . # ditto but better (e.g., dependency & metadata handling) python setup.py develop # build extensions and install (symbolic) to PATH. pip install -e . # ditto but better (e.g., dependency & metadata handling) ''' setup( name='raymarching', # package name, import this to use python API ext_modules=[ CUDAExtension( name='_raymarching', # extension name, import this to use CUDA API sources=[os.path.join(_src_path, 'src', f) for f in [ 'raymarching.cu', 'bindings.cpp', ]], extra_compile_args={ 'cxx': c_flags, 'nvcc': nvcc_flags, } ), ], cmdclass={ 'build_ext': BuildExtension, } ) ================================================ FILE: 3drecon/raymarching/src/bindings.cpp ================================================ #include #include "raymarching.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // utils m.def("packbits", &packbits, "packbits (CUDA)"); m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); m.def("morton3D", &morton3D, "morton3D (CUDA)"); m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); // train m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); // infer m.def("march_rays", &march_rays, "march rays (CUDA)"); m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); } 
================================================ FILE: 3drecon/raymarching/src/raymarching.cu ================================================ #include #include #include #include #include #include #include #include #include #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } inline constexpr __device__ float PI() { return 3.141592653589793f; } inline constexpr __device__ float RPI() { return 0.3183098861837907f; } template inline __host__ __device__ T div_round_up(T val, T divisor) { return (val + divisor - 1) / divisor; } inline __host__ __device__ float signf(const float x) { return copysignf(1.0, x); } inline __host__ __device__ float clamp(const float x, const float min, const float max) { return fminf(max, fmaxf(min, x)); } inline __host__ __device__ void swapf(float& a, float& b) { float c = a; a = b; b = c; } inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); int exponent; frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... 
return fminf(max_cascade - 1, fmaxf(0, exponent)); } inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { const float mx = dt * H * 0.5; int exponent; frexpf(mx, &exponent); return fminf(max_cascade - 1, fmaxf(0, exponent)); } inline __host__ __device__ uint32_t __expand_bits(uint32_t v) { v = (v * 0x00010001u) & 0xFF0000FFu; v = (v * 0x00000101u) & 0x0F00F00Fu; v = (v * 0x00000011u) & 0xC30C30C3u; v = (v * 0x00000005u) & 0x49249249u; return v; } inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) { uint32_t xx = __expand_bits(x); uint32_t yy = __expand_bits(y); uint32_t zz = __expand_bits(z); return xx | (yy << 1) | (zz << 2); } inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) { x = x & 0x49249249; x = (x | (x >> 2)) & 0xc30c30c3; x = (x | (x >> 4)) & 0x0f00f00f; x = (x | (x >> 8)) & 0xff0000ff; x = (x | (x >> 16)) & 0x0000ffff; return x; } //////////////////////////////////////////////////// ///////////// utils ///////////// //////////////////////////////////////////////////// // rays_o/d: [N, 3] // nears/fars: [N] // scalar_t should always be float in use. 
template __global__ void kernel_near_far_from_aabb( const scalar_t * __restrict__ rays_o, const scalar_t * __restrict__ rays_d, const scalar_t * __restrict__ aabb, const uint32_t N, const float min_near, scalar_t * nears, scalar_t * fars ) { // parallel per ray const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate rays_o += n * 3; rays_d += n * 3; const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; // get near far (assume cube scene) float near = (aabb[0] - ox) * rdx; float far = (aabb[3] - ox) * rdx; if (near > far) swapf(near, far); float near_y = (aabb[1] - oy) * rdy; float far_y = (aabb[4] - oy) * rdy; if (near_y > far_y) swapf(near_y, far_y); if (near > far_y || near_y > far) { nears[n] = fars[n] = std::numeric_limits::max(); return; } if (near_y > near) near = near_y; if (far_y < far) far = far_y; float near_z = (aabb[2] - oz) * rdz; float far_z = (aabb[5] - oz) * rdz; if (near_z > far_z) swapf(near_z, far_z); if (near > far_z || near_z > far) { nears[n] = fars[n] = std::numeric_limits::max(); return; } if (near_z > near) near = near_z; if (far_z < far) far = far_z; if (near < min_near) near = min_near; nears[n] = near; fars[n] = far; } void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_o.scalar_type(), "near_far_from_aabb", ([&] { kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr()); })); } // rays_o/d: [N, 3] // radius: float // coords: [N, 2] template __global__ void kernel_sph_from_ray( const scalar_t * __restrict__ rays_o, const scalar_t * __restrict__ rays_d, const float radius, const uint32_t N, scalar_t * coords ) { 
// parallel per ray const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate rays_o += n * 3; rays_d += n * 3; coords += n * 2; const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; // solve t from || o + td || = radius const float A = dx * dx + dy * dy + dz * dz; const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2 const float C = ox * ox + oy * oy + oz * oz - radius * radius; const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive) // solve theta, phi (assume y is the up axis) const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI) const float phi = atan2(z, x); // [-PI, PI) // normalize to [-1, 1] coords[0] = 2 * theta * RPI() - 1; coords[1] = phi * RPI(); } void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_o.scalar_type(), "sph_from_ray", ([&] { kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr()); })); } // coords: int32, [N, 3] // indices: int32, [N] __global__ void kernel_morton3D( const int * __restrict__ coords, const uint32_t N, int * indices ) { // parallel const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate coords += n * 3; indices[n] = __morton3D(coords[0], coords[1], coords[2]); } void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { static constexpr uint32_t N_THREAD = 128; kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr()); } // indices: int32, [N] // coords: int32, [N, 3] __global__ void kernel_morton3D_invert( const int * __restrict__ indices, const uint32_t N, int * coords ) { // parallel const uint32_t n = 
threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate coords += n * 3; const int ind = indices[n]; coords[0] = __morton3D_invert(ind >> 0); coords[1] = __morton3D_invert(ind >> 1); coords[2] = __morton3D_invert(ind >> 2); } void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { static constexpr uint32_t N_THREAD = 128; kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr()); } // grid: float, [C, H, H, H] // N: int, C * H * H * H / 8 // density_thresh: float // bitfield: uint8, [N] template __global__ void kernel_packbits( const scalar_t * __restrict__ grid, const uint32_t N, const float density_thresh, uint8_t * bitfield ) { // parallel per byte const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate grid += n * 8; uint8_t bits = 0; #pragma unroll for (uint8_t i = 0; i < 8; i++) { bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0; } bitfield[n] = bits; } void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( grid.scalar_type(), "packbits", ([&] { kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr()); })); } //////////////////////////////////////////////////// ///////////// training ///////////// //////////////////////////////////////////////////// // rays_o/d: [N, 3] // grid: [CHHH / 8] // xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] // dirs: [M, 3] // rays: [N, 3], idx, offset, num_steps template __global__ void kernel_march_rays_train( const scalar_t * __restrict__ rays_o, const scalar_t * __restrict__ rays_d, const uint8_t * __restrict__ grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const scalar_t* __restrict__ nears, const scalar_t* __restrict__ fars, scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, int * rays, 
int * counter, const scalar_t* __restrict__ noises ) { // parallel per ray const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate rays_o += n * 3; rays_d += n * 3; // ray marching const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; const float rH = 1 / (float)H; const float H3 = H * H * H; const float near = nears[n]; const float far = fars[n]; const float noise = noises[n]; const float dt_min = 2 * SQRT3() / max_steps; const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; float t0 = near; // perturb t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; // first pass: estimation of num_steps float t = t0; uint32_t num_steps = 0; //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); while (t < far && num_steps < max_steps) { // current point const float x = clamp(ox + t * dx, -bound, bound); const float y = clamp(oy + t * dy, -bound, bound); const float z = clamp(oz + t * dz, -bound, bound); const float dt = clamp(t * dt_gamma, dt_min, dt_max); // get mip level const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] const float mip_bound = fminf(scalbnf(1.0f, level), bound); const float mip_rbound = 1 / mip_bound; // convert to nearest grid position const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const uint32_t index = level * H3 + __morton3D(nx, ny, nz); const bool occ = grid[index / 8] & (1 << (index % 8)); // if occpuied, advance a small step, and write to output //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); if (occ) { num_steps++; t += dt; // else, skip a large step (basically skip a voxel grid) } else { // calc 
distance to next voxel const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); // step until next voxel do { t += clamp(t * dt_gamma, dt_min, dt_max); } while (t < tt); } } //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); // second pass: really locate and write points & dirs uint32_t point_index = atomicAdd(counter, num_steps); uint32_t ray_index = atomicAdd(counter + 1, 1); //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); // write rays rays[ray_index * 3] = n; rays[ray_index * 3 + 1] = point_index; rays[ray_index * 3 + 2] = num_steps; if (num_steps == 0) return; if (point_index + num_steps > M) return; xyzs += point_index * 3; dirs += point_index * 3; deltas += point_index * 2; t = t0; uint32_t step = 0; float last_t = t; while (t < far && step < num_steps) { // current point const float x = clamp(ox + t * dx, -bound, bound); const float y = clamp(oy + t * dy, -bound, bound); const float z = clamp(oz + t * dz, -bound, bound); const float dt = clamp(t * dt_gamma, dt_min, dt_max); // get mip level const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] const float mip_bound = fminf(scalbnf(1.0f, level), bound); const float mip_rbound = 1 / mip_bound; // convert to nearest grid position const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); // query grid const uint32_t index = level * H3 + __morton3D(nx, ny, nz); const bool occ = grid[index / 8] & (1 << (index % 
8)); // if occpuied, advance a small step, and write to output if (occ) { // write step xyzs[0] = x; xyzs[1] = y; xyzs[2] = z; dirs[0] = dx; dirs[1] = dy; dirs[2] = dz; t += dt; deltas[0] = dt; deltas[1] = t - last_t; // used to calc depth last_t = t; xyzs += 3; dirs += 3; deltas += 2; step++; // else, skip a large step (basically skip a voxel grid) } else { // calc distance to next voxel const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); // step until next voxel do { t += clamp(t * dt_gamma, dt_min, dt_max); } while (t < tt); } } } void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_o.scalar_type(), "march_rays_train", ([&] { kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr(), fars.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), noises.data_ptr()); })); } // sigmas: [M] // rgbs: [M, 3] // deltas: [M, 2] // rays: [N, 3], idx, offset, num_steps // weights_sum: [N], final pixel alpha // depth: [N,] // image: [N, 3] template __global__ void kernel_composite_rays_train_forward( const scalar_t * __restrict__ sigmas, const scalar_t * __restrict__ rgbs, const scalar_t * __restrict__ deltas, const int * __restrict__ rays, const uint32_t M, const 
uint32_t N, const float T_thresh, scalar_t * weights_sum, scalar_t * depth, scalar_t * image ) { // parallel per ray const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate uint32_t index = rays[n * 3]; uint32_t offset = rays[n * 3 + 1]; uint32_t num_steps = rays[n * 3 + 2]; // empty ray, or ray that exceed max step count. if (num_steps == 0 || offset + num_steps > M) { weights_sum[index] = 0; depth[index] = 0; image[index * 3] = 0; image[index * 3 + 1] = 0; image[index * 3 + 2] = 0; return; } sigmas += offset; rgbs += offset * 3; deltas += offset * 2; // accumulate uint32_t step = 0; scalar_t T = 1.0f; scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0; while (step < num_steps) { const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); const scalar_t weight = alpha * T; r += weight * rgbs[0]; g += weight * rgbs[1]; b += weight * rgbs[2]; t += deltas[1]; // real delta d += weight * t; ws += weight; T *= 1.0f - alpha; // minimal remained transmittence if (T < T_thresh) break; //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); // locate sigmas++; rgbs += 3; deltas += 2; step++; } //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); // write weights_sum[index] = ws; // weights_sum depth[index] = d; image[index * 3] = r; image[index * 3 + 1] = g; image[index * 3 + 2] = b; } void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( sigmas.scalar_type(), "composite_rays_train_forward", ([&] { kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); })); } // 
grad_weights_sum: [N,] // grad: [N, 3] // sigmas: [M] // rgbs: [M, 3] // deltas: [M, 2] // rays: [N, 3], idx, offset, num_steps // weights_sum: [N,], weights_sum here // image: [N, 3] // grad_sigmas: [M] // grad_rgbs: [M, 3] template __global__ void kernel_composite_rays_train_backward( const scalar_t * __restrict__ grad_weights_sum, const scalar_t * __restrict__ grad_image, const scalar_t * __restrict__ sigmas, const scalar_t * __restrict__ rgbs, const scalar_t * __restrict__ deltas, const int * __restrict__ rays, const scalar_t * __restrict__ weights_sum, const scalar_t * __restrict__ image, const uint32_t M, const uint32_t N, const float T_thresh, scalar_t * grad_sigmas, scalar_t * grad_rgbs ) { // parallel per ray const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate uint32_t index = rays[n * 3]; uint32_t offset = rays[n * 3 + 1]; uint32_t num_steps = rays[n * 3 + 2]; if (num_steps == 0 || offset + num_steps > M) return; grad_weights_sum += index; grad_image += index * 3; weights_sum += index; image += index * 3; sigmas += offset; rgbs += offset * 3; deltas += offset * 2; grad_sigmas += offset; grad_rgbs += offset * 3; // accumulate uint32_t step = 0; scalar_t T = 1.0f; const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; scalar_t r = 0, g = 0, b = 0, ws = 0; while (step < num_steps) { const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); const scalar_t weight = alpha * T; r += weight * rgbs[0]; g += weight * rgbs[1]; b += weight * rgbs[2]; ws += weight; T *= 1.0f - alpha; // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. 
// write grad_rgbs grad_rgbs[0] = grad_image[0] * weight; grad_rgbs[1] = grad_image[1] * weight; grad_rgbs[2] = grad_image[2] * weight; // write grad_sigmas grad_sigmas[0] = deltas[0] * ( grad_image[0] * (T * rgbs[0] - (r_final - r)) + grad_image[1] * (T * rgbs[1] - (g_final - g)) + grad_image[2] * (T * rgbs[2] - (b_final - b)) + grad_weights_sum[0] * (1 - ws_final) ); //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); // minimal remained transmittence if (T < T_thresh) break; // locate sigmas++; rgbs += 3; deltas += 2; grad_sigmas++; grad_rgbs += 3; step++; } } void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad_image.scalar_type(), "composite_rays_train_backward", ([&] { kernel_composite_rays_train_backward<<>>(grad_weights_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr()); })); } //////////////////////////////////////////////////// ///////////// infernce ///////////// //////////////////////////////////////////////////// template __global__ void kernel_march_rays( const uint32_t n_alive, const uint32_t n_step, const int* __restrict__ rays_alive, const scalar_t* __restrict__ rays_t, const scalar_t* __restrict__ rays_o, const scalar_t* __restrict__ rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const uint8_t * __restrict__ grid, const scalar_t* __restrict__ nears, const scalar_t* __restrict__ fars, 
scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, const scalar_t* __restrict__ noises ) { const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= n_alive) return; const int index = rays_alive[n]; // ray id const float noise = noises[n]; // locate rays_o += index * 3; rays_d += index * 3; xyzs += n * n_step * 3; dirs += n * n_step * 3; deltas += n * n_step * 2; const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; const float rH = 1 / (float)H; const float H3 = H * H * H; float t = rays_t[index]; // current ray's t const float near = nears[index], far = fars[index]; const float dt_min = 2 * SQRT3() / max_steps; const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; // march for n_step steps, record points uint32_t step = 0; // introduce some randomness t += clamp(t * dt_gamma, dt_min, dt_max) * noise; float last_t = t; while (t < far && step < n_step) { // current point const float x = clamp(ox + t * dx, -bound, bound); const float y = clamp(oy + t * dy, -bound, bound); const float z = clamp(oz + t * dz, -bound, bound); const float dt = clamp(t * dt_gamma, dt_min, dt_max); // get mip level const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] const float mip_bound = fminf(scalbnf(1, level), bound); const float mip_rbound = 1 / mip_bound; // convert to nearest grid position const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const uint32_t index = level * H3 + __morton3D(nx, ny, nz); const bool occ = grid[index / 8] & (1 << (index % 8)); // if occpuied, advance a small step, and write to output if (occ) { // write step xyzs[0] = x; xyzs[1] = y; xyzs[2] = z; dirs[0] = dx; dirs[1] = dy; dirs[2] = dz; // calc dt t += dt; 
deltas[0] = dt; deltas[1] = t - last_t; // used to calc depth last_t = t; // step xyzs += 3; dirs += 3; deltas += 2; step++; // else, skip a large step (basically skip a voxel grid) } else { // calc distance to next voxel const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); // step until next voxel do { t += clamp(t * dt_gamma, dt_min, dt_max); } while (t < tt); } } } void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_o.scalar_type(), "march_rays", ([&] { kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr()); })); } template __global__ void kernel_composite_rays( const uint32_t n_alive, const uint32_t n_step, const float T_thresh, int* rays_alive, scalar_t* rays_t, const scalar_t* __restrict__ sigmas, const scalar_t* __restrict__ rgbs, const scalar_t* __restrict__ deltas, scalar_t* weights_sum, scalar_t* depth, scalar_t* image ) { const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= n_alive) return; const int index = rays_alive[n]; // ray id // locate sigmas += n * n_step; rgbs += n * n_step * 3; deltas += n * n_step * 2; rays_t += 
index; weights_sum += index; depth += index; image += index * 3; scalar_t t = rays_t[0]; // current ray's t scalar_t weight_sum = weights_sum[0]; scalar_t d = depth[0]; scalar_t r = image[0]; scalar_t g = image[1]; scalar_t b = image[2]; // accumulate uint32_t step = 0; while (step < n_step) { // ray is terminated if delta == 0 if (deltas[0] == 0) break; const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); /* T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) w_i = alpha_i * T_i --> T_i = 1 - \sum_{j=0}^{i-1} w_j */ const scalar_t T = 1 - weight_sum; const scalar_t weight = alpha * T; weight_sum += weight; t += deltas[1]; // real delta d += weight * t; r += weight * rgbs[0]; g += weight * rgbs[1]; b += weight * rgbs[2]; //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); // ray is terminated if T is too small // use a larger bound to further accelerate inference if (T < T_thresh) break; // locate sigmas++; rgbs += 3; deltas += 2; step++; } //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); // rays_alive = -1 means ray is terminated early. if (step < n_step) { rays_alive[n] = -1; } else { rays_t[0] = t; } weights_sum[0] = weight_sum; // this is the thing I needed! 
depth[0] = d; image[0] = r; image[1] = g; image[2] = b; } void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( image.scalar_type(), "composite_rays", ([&] { kernel_composite_rays<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); })); } ================================================ FILE: 3drecon/raymarching/src/raymarching.h ================================================ #pragma once #include #include void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor 
def weights_init(m):
    """Re-initialise a linear layer in place; meant to be used with ``Module.apply``.

    ``nn.Linear`` modules get Kaiming-normal weights and, when they have a bias,
    an all-zero bias. Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight.data)
    bias = m.bias
    if bias is None:
        return
    nn.init.zeros_(bias.data)
class Agg(nn.Module):
    """Pool per-view features into a single 16-channel feature per sample point.

    The last 4 channels of the input are embedded by ``view_fc`` and added onto
    the remaining ``feat_ch`` channels. Each view is then scored from its own
    feature together with the masked mean and variance over views, and the
    views are blended with a softmax over those scores.
    """

    def __init__(self, feat_ch):
        super(Agg, self).__init__()
        self.feat_ch = feat_ch

        # Construction and initialisation order is kept as-is so that a fixed
        # RNG seed yields the same parameters.
        self.view_fc = nn.Sequential(nn.Linear(4, feat_ch), nn.ReLU())
        self.view_fc.apply(weights_init)

        self.global_fc = nn.Sequential(nn.Linear(feat_ch * 3, 32), nn.ReLU())
        self.agg_w_fc = nn.Linear(32, 1)
        self.fc = nn.Linear(32, 16)
        for module in (self.global_fc, self.agg_w_fc, self.fc):
            module.apply(weights_init)

    def masked_mean_var(self, img_feat_rgb, source_img_mask):
        """Return mean and variance over the view axis, counting only masked-in views.

        img_feat_rgb: b,d,n,f   source_img_mask: b,n
        """
        b, n = source_img_mask.shape
        mask = source_img_mask.view(b, 1, n, 1)
        # 1e-5 keeps the division finite when no view is valid
        valid_count = torch.sum(mask, dim=-2) + 1e-5
        mean = torch.sum(mask * img_feat_rgb, dim=-2) / valid_count
        sq_dev = (img_feat_rgb - mean.unsqueeze(-2)) ** 2
        var = torch.sum(sq_dev * mask, dim=-2) / valid_count
        return mean, var

    def forward(self, img_feat_rgb_dir, source_img_mask):
        # img_feat_rgb_dir: b,d,n,f (the last 4 channels are the view channels)
        b, _, n, _ = img_feat_rgb_dir.shape

        view_embedding = self.view_fc(img_feat_rgb_dir[..., -4:])  # b,d,n,feat_ch
        img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_embedding

        mean_feat, var_feat = self.masked_mean_var(img_feat_rgb, source_img_mask)
        # broadcast the per-point statistics back onto every view
        var_per_view = var_feat.view(b, -1, 1, self.feat_ch).expand(-1, -1, n, -1)
        mean_per_view = mean_feat.view(b, -1, 1, self.feat_ch).expand(-1, -1, n, -1)

        stacked = torch.cat([img_feat_rgb, var_per_view, mean_per_view], dim=-1)  # b,d,n,3*feat_ch
        global_feat = self.global_fc(stacked)  # b,d,n,32
        scores = self.agg_w_fc(global_feat)  # b,d,n,1

        # masked-out views get a huge negative score, i.e. zero softmax weight
        invalid = source_img_mask.reshape(b, 1, n, 1) == 0
        scores = scores.masked_fill(invalid, -1e7)
        view_weight = F.softmax(scores, dim=-2)

        pooled = (global_feat * view_weight).sum(dim=-2)  # b,d,32
        return self.fc(pooled)
class MinCostRegNet(nn.Module):
    """Small 3D U-Net: two stride-2 encoder stages and two transposed-conv decoder stages.

    ``forward`` returns ``(feat, depth)``: an 8-channel and a 1-channel volume
    computed from the same decoded feature map.
    """

    @staticmethod
    def _upsample_stage(in_channels, out_channels, norm_act):
        """Transposed conv that doubles each spatial size, followed by normalisation."""
        return nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, 3, padding=1, output_padding=1,
                               stride=2, bias=False),
            norm_act(out_channels))

    def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
        super(MinCostRegNet, self).__init__()
        # encoder: full resolution -> 1/2 -> 1/4
        self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)

        self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
        self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)

        self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
        self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)

        # decoder: 1/4 -> 1/2 -> full resolution
        self.conv9 = self._upsample_stage(32, 16, norm_act)
        self.conv11 = self._upsample_stage(16, 8, norm_act)

        # output heads
        self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
        self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))

    def forward(self, x):
        full_res = self.conv0(x)
        half_res = self.conv2(self.conv1(full_res))
        quarter_res = self.conv4(self.conv3(half_res))

        # decode with additive skip connections; drop each skip tensor once used
        x = quarter_res
        x = half_res + self.conv9(x)
        del half_res
        x = full_res + self.conv11(x)
        del full_res

        feat = self.feat_conv(x)
        depth = self.depth_conv(x)
        return feat, depth
class FeatureNet(nn.Module):
    """Three-level feature pyramid over an RGB image.

    The encoder halves the spatial size twice (``conv1`` and ``conv2`` use
    stride 2). The decoder upsamples the coarse features and adds 1x1 lateral
    projections of the finer encoder levels, then smooths each merged level
    with a 3x3 convolution.

    ``forward`` returns ``(feat2, feat1, feat0)`` with 32, 16 and 8 channels,
    from the coarsest level to the input resolution.
    """

    def __init__(self, norm_act=nn.BatchNorm2d):
        super(FeatureNet, self).__init__()
        # encoder
        self.conv0 = nn.Sequential(ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
                                   ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
        self.conv1 = nn.Sequential(ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
                                   ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
        self.conv2 = nn.Sequential(ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
                                   ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))

        # 1x1 projections that bring every level to 32 channels before merging
        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        # 3x3 smoothing of the merged levels
        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        """Bilinearly upsample ``x`` to the spatial size of ``y`` and add them.

        Resizing to ``y``'s size instead of using a fixed factor of 2 keeps the
        pyramid usable when the input height/width is not divisible by 4 (a
        stride-2 convolution of an odd size rounds up, so doubling would
        overshoot). When ``y`` is exactly twice as large as ``x`` the result is
        the same as with ``scale_factor=2``, because with ``align_corners=True``
        the sampling grid depends only on the input and output sizes.
        """
        return F.interpolate(x, size=y.shape[-2:], mode='bilinear', align_corners=True) + y

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)

        # top-down pathway
        feat2 = self.toplayer(conv2)
        feat1 = self._upsample_add(feat2, self.lat1(conv1))
        feat0 = self._upsample_add(feat1, self.lat0(conv0))

        # feat0 above is built from the un-smoothed feat1, so smooth afterwards
        feat1 = self.smooth1(feat1)
        feat0 = self.smooth0(feat0)
        return feat2, feat1, feat0
get_embedder(multires, input_dims=d_in) self.embed_fn_fine = embed_fn dims[0] = input_ch self.num_layers = len(dims) self.skip_in = skip_in self.scale = scale for l in range(0, self.num_layers - 1): if l + 1 in self.skip_in: out_dim = dims[l + 1] - dims[0] else: out_dim = dims[l + 1] lin = nn.Linear(dims[l], out_dim) if geometric_init: if l == self.num_layers - 2: if not inside_outside: torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) torch.nn.init.constant_(lin.bias, -bias) else: torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) torch.nn.init.constant_(lin.bias, bias) elif multires > 0 and l == 0: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) elif multires > 0 and l in self.skip_in: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) else: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) if weight_norm: lin = nn.utils.weight_norm(lin) setattr(self, "lin" + str(l), lin) self.activation = nn.Softplus(beta=100) def forward(self, inputs): inputs = inputs * self.scale if self.embed_fn_fine is not None: inputs = self.embed_fn_fine(inputs) x = inputs for l in range(0, self.num_layers - 1): lin = getattr(self, "lin" + str(l)) if l in self.skip_in: x = torch.cat([x, inputs], -1) / np.sqrt(2) x = lin(x) if l < self.num_layers - 2: x = self.activation(x) return x def sdf(self, x): return self.forward(x)[..., :1] def sdf_hidden_appearance(self, x): return self.forward(x) def gradient(self, x): x.requires_grad_(True) with torch.enable_grad(): y = self.sdf(x) d_output = torch.ones_like(y, requires_grad=False, device=y.device) gradients = torch.autograd.grad( outputs=y, inputs=x, grad_outputs=d_output, 
create_graph=True, retain_graph=True, only_inputs=True)[0] return gradients def sdf_normal(self, x): x.requires_grad_(True) with torch.enable_grad(): y = self.sdf(x) d_output = torch.ones_like(y, requires_grad=False, device=y.device) gradients = torch.autograd.grad( outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True)[0] return y[..., :1].detach(), gradients.detach() class SDFNetworkWithFeature(nn.Module): def __init__(self, cube, dp_in, df_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0, bias=0.5, scale=1, geometric_init=True, weight_norm=True, inside_outside=False, cube_length=0.5): super().__init__() self.register_buffer("cube", cube) self.cube_length = cube_length dims = [dp_in+df_in] + [d_hidden for _ in range(n_layers)] + [d_out] self.embed_fn_fine = None if multires > 0: embed_fn, input_ch = get_embedder(multires, input_dims=dp_in) self.embed_fn_fine = embed_fn dims[0] = input_ch + df_in self.num_layers = len(dims) self.skip_in = skip_in self.scale = scale for l in range(0, self.num_layers - 1): if l + 1 in self.skip_in: out_dim = dims[l + 1] - dims[0] else: out_dim = dims[l + 1] lin = nn.Linear(dims[l], out_dim) if geometric_init: if l == self.num_layers - 2: if not inside_outside: torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) torch.nn.init.constant_(lin.bias, -bias) else: torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) torch.nn.init.constant_(lin.bias, bias) elif multires > 0 and l == 0: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) elif multires > 0 and l in self.skip_in: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) else: torch.nn.init.constant_(lin.bias, 0.0) 
class VanillaMLP(nn.Module):
    """Plain fully connected network with sphere initialisation and weight norm.

    The stack is ``n_hidden_layers`` blocks of (linear, activation) of width
    ``n_neurons`` followed by one output linear layer. ``forward`` always runs
    in float32 with CUDA autocast disabled.
    """

    def __init__(self, dim_in, dim_out, n_neurons, n_hidden_layers):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers
        # fixed configuration: sphere initialisation of radius 0.5, weight-normalised layers
        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = 0.5

        stack = [
            self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for _ in range(self.n_hidden_layers - 1):
            stack.append(self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False))
            stack.append(self.make_activation())
        stack.append(self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True))
        self.layers = nn.Sequential(*stack)

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        return self.layers(x.float())

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        """Create one linear layer, initialise it according to its position, and
        wrap it in weight normalisation when ``self.weight_norm`` is set."""
        layer = nn.Linear(dim_in, dim_out, bias=True)  # network without bias will degrade quality
        init = torch.nn.init

        if not self.sphere_init:
            init.constant_(layer.bias, 0.0)
            init.kaiming_uniform_(layer.weight, nonlinearity='relu')
        elif is_last:
            # output layer: bias -radius, near-constant positive weights
            init.constant_(layer.bias, -self.sphere_init_radius)
            init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
        elif is_first:
            # input layer: only the first three input channels start with non-zero weights
            init.constant_(layer.bias, 0.0)
            init.constant_(layer.weight[:, 3:], 0.0)
            init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            init.constant_(layer.bias, 0.0)
            init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        """Return the activation that matches the initialisation scheme."""
        return nn.Softplus(beta=100) if self.sphere_init else nn.ReLU(inplace=True)
class RenderingFFNetwork(nn.Module):
    """Colour head built on tiny-cuda-nn: a spherical-harmonics encoding of the
    reflection direction followed by a fully fused MLP.

    The MLP input is ``[feature_vectors, normals, SH(reflective)]``, i.e.
    ``in_feats_dim + 3 + dir_encoder.n_output_dims`` channels, and the output
    is a 3-channel colour squashed to (0, 1) by a sigmoid.
    """

    def __init__(self, in_feats_dim=12):
        super().__init__()
        self.dir_encoder = tcnn.Encoding(
            n_input_dims=3,
            encoding_config={
                "otype": "SphericalHarmonics",
                "degree": 4,
            },
        )
        self.color_mlp = tcnn.Network(
            n_input_dims = in_feats_dim + 3 + self.dir_encoder.n_output_dims,
            n_output_dims = 3,
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 2,
            },
        )

    def forward(self, points, normals, view_dirs, feature_vectors):
        """Predict colours.

        ``points`` is accepted for signature compatibility with the other
        rendering networks but is not used. ``normals`` and ``view_dirs`` are
        normalised here, so they do not need to be unit length.
        """
        normals = F.normalize(normals, dim=-1)
        view_dirs = F.normalize(view_dirs, dim=-1)
        # view direction mirrored about the normal: 2 (v . n) n - v
        reflective = torch.sum(view_dirs * normals, -1, keepdim=True) * normals * 2 - view_dirs

        # tiny-cuda-nn documents that its SphericalHarmonics encoding expects unit
        # vectors remapped from [-1, 1] to the unit cube as (v + 1) / 2; feeding the
        # raw direction would evaluate the harmonics on a distorted vector.
        encoded_dir = self.dir_encoder((reflective + 1.0) * 0.5)

        x = torch.cat([feature_vectors, normals, encoded_dir], -1)
        # the fused MLP may return half precision; promote before the sigmoid
        colors = self.color_mlp(x).float()
        colors = F.sigmoid(colors)
        return colors
class RenderingNetwork(nn.Module):
    """MLP colour head (borrowed from IDR: https://github.com/lioryariv/idr).

    The layers are stored as attributes ``lin0`` .. ``lin{num_layers - 2}``.
    With ``use_view_dir`` the input is ``[points, reflective, normals,
    feature_vectors]``, where ``reflective`` is the view direction mirrored
    about the normal (optionally positionally encoded); without it the input is
    ``[points, normals, feature_vectors]`` with the normals passed through
    unnormalised.
    """

    def __init__(self, d_feature, d_in, d_out, d_hidden, n_layers,
                 weight_norm=True, multires_view=0, squeeze_out=True, use_view_dir=True):
        super().__init__()

        self.squeeze_out = squeeze_out
        self.rgb_act=F.sigmoid
        self.use_view_dir=use_view_dir

        dims = [d_in + d_feature]
        dims.extend(d_hidden for _ in range(n_layers))
        dims.append(d_out)

        self.embedview_fn = None
        if multires_view > 0:
            self.embedview_fn, input_ch = get_embedder(multires_view)
            # the encoded direction replaces the 3 raw direction channels
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for idx, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
            lin = nn.Linear(fan_in, fan_out)
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(idx), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        parts = [points]
        if self.use_view_dir:
            view_dirs = F.normalize(view_dirs, dim=-1)
            normals = F.normalize(normals, dim=-1)
            # 2 (v . n) n - v
            reflective = torch.sum(view_dirs*normals, -1, keepdim=True) * normals * 2 - view_dirs
            if self.embedview_fn is not None:
                reflective = self.embedview_fn(reflective)
            parts.append(reflective)
        parts.append(normals)
        parts.append(feature_vectors)

        x = torch.cat(parts, dim=-1)

        last = self.num_layers - 2
        for idx in range(last + 1):
            x = getattr(self, "lin" + str(idx))(x)
            # every layer but the output one is followed by a ReLU
            if idx < last:
                x = self.relu(x)

        if self.squeeze_out:
            x = self.rgb_act(x)
        return x
def sample_pdf(bins, weights, n_samples, det=False):
    """Draw samples from a piecewise-constant PDF by inverse-CDF sampling.

    This implementation is from NeRF.

    Args:
        bins: ``[B, T]`` bin edges (old z values).
        weights: ``[B, T - 1]`` unnormalised weight of each bin.
        n_samples: number of samples to draw per row.
        det: if True use evenly spaced quantiles, otherwise uniform random ones.

    Returns:
        ``[B, n_samples]`` tensor of new z values, in the dtype and on the
        device of ``weights``.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples. They are created directly with the CDF's dtype and
    # device: torch.searchsorted requires matching dtypes, and this avoids a
    # CPU allocation followed by a host-to-device copy on every call.
    sample_shape = list(cdf.shape[:-1]) + [n_samples]
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples,
                           dtype=cdf.dtype, device=cdf.device)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape, dtype=cdf.dtype, device=cdf.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # a (near) zero-width CDF interval would divide by zero; fall back to 1
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs):
    """Volume-render rays by sampling along each ray in pure PyTorch.

    Each ray is sampled at ``num_steps`` positions between its near and far
    intersections with the active AABB; if ``upsample_steps > 0`` the samples
    are refined once by importance sampling (``sample_pdf``) from the coarse
    weights. Densities come from ``self.density`` and colours from
    ``self.color``.

    Args:
        rays_o, rays_d: ray origins / directions, last dimension 3.
        num_steps: number of coarse samples per ray.
        upsample_steps: number of importance samples per ray (0 disables).
        bg_color: background colour blended behind the ray; ``None`` means 1.
        perturb: jitter the coarse samples by up to half a step.
        **kwargs: ignored by this method.

    Returns:
        dict with ``'depth'`` (shaped like ``rays_o.shape[:-1]``), ``'image'``
        (same prefix plus 3) and ``'weights_sum'`` (flat, one value per ray).
    """
    # rays_o, rays_d: [B, N, 3], assumes B == 1
    # bg_color: [3] in range [0, 1]
    # return: image: [B, N, 3], depth: [B, N]

    prefix = rays_o.shape[:-1]
    rays_o = rays_o.contiguous().view(-1, 3)
    rays_d = rays_d.contiguous().view(-1, 3)

    N = rays_o.shape[0] # N = B * N, in fact
    device = rays_o.device

    # choose aabb
    aabb = self.aabb_train if self.training else self.aabb_infer

    # sample steps
    nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near)
    # add a trailing dim so they broadcast against [N, T] sample tensors
    nears.unsqueeze_(-1)
    fars.unsqueeze_(-1)

    #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')

    z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T]
    z_vals = z_vals.expand((N, num_steps)) # [N, T]
    z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]

    # perturb z_vals
    sample_dist = (fars - nears) / num_steps
    if perturb:
        z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
        #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.

    # generate xyzs
    xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
    xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.

    #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

    # query SDF and RGB
    density_outputs = self.density(xyzs.reshape(-1, 3))

    #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
    # reshape every returned tensor (sigma and any extras) to [N, T, C]
    for k, v in density_outputs.items():
        density_outputs[k] = v.view(N, num_steps, -1)

    # upsample z_vals (nerf-like)
    if upsample_steps > 0:
        # NOTE(review): the extent of this no_grad block is reconstructed from
        # a whitespace-collapsed source -- confirm against the upstream file.
        with torch.no_grad():

            deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
            # the last sample has no successor; pad with the nominal step size
            deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)

            alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T]
            alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1]
            weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T]

            # sample new z_vals
            z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1]
            new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t]

            new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
            new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.

        # only forward new points to save computation
        new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
        #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
        for k, v in new_density_outputs.items():
            new_density_outputs[k] = v.view(N, upsample_steps, -1)

        # re-order
        z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
        z_vals, z_index = torch.sort(z_vals, dim=1)

        xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
        xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs))

        # apply the same depth ordering to every density output
        for k in density_outputs:
            tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1)
            density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output))

    deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
    deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1)
    alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t]
    alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1]
    weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]

    dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
    # flatten back to [N * (T+t), C] so the outputs can be passed to self.color
    for k, v in density_outputs.items():
        density_outputs[k] = v.view(-1, v.shape[-1])

    # only samples with a non-negligible weight are flagged for colour evaluation
    mask = weights > 1e-4 # hard coded
    rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs)
    rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]

    #print(xyzs.shape, 'valid_rgb:', mask.sum().item())

    # calculate weight_sum (mask)
    weights_sum = weights.sum(dim=-1) # [N]

    # calculate depth (normalised to [0, 1] between near and far)
    ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
    depth = torch.sum(weights * ori_z_vals, dim=-1)

    # calculate color
    image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]

    # mix background color
    if self.bg_radius > 0:
        # use the bg model to calculate bg_color
        # NOTE(review): `self.background` is not defined on NGPRenderer in the
        # visible source; presumably a subclass must provide it -- confirm.
        sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1]
        bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3]
    elif bg_color is None:
        bg_color = 1

    image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

    image = image.view(*prefix, 3)
    depth = depth.view(*prefix)

    # tmp: reg loss in mip-nerf 360
    # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
    # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
    # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()

    return {
        'depth': depth,
        'image': image,
        'weights_sum': weights_sum,
    }
raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh) image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears) image = image.view(*prefix, 3) depth = depth.view(*prefix) else: # allocate outputs # if use autocast, must init as half so it won't be autocasted and lose reference. #dtype = torch.half if torch.is_autocast_enabled() else torch.float32 # output should always be float32! only network inference uses half. dtype = torch.float32 weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) n_alive = N rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] rays_t = nears.clone() # [N] step = 0 while step < max_steps: # count alive rays n_alive = rays_alive.shape[0] # exit loop if n_alive <= 0: break # decide compact_steps n_step = max(min(N // n_alive, 8), 1) xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) sigmas, rgbs = self(xyzs, dirs) # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. 
# sigmas = density_outputs['sigma'] # rgbs = self.color(xyzs, dirs, **density_outputs) sigmas = self.density_scale * sigmas raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh) rays_alive = rays_alive[rays_alive >= 0] #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') step += n_step image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears) image = image.view(*prefix, 3) depth = depth.view(*prefix) results['weights_sum'] = weights_sum results['depth'] = depth results['image'] = image return results @torch.no_grad() def mark_untrained_grid(self, poses, intrinsic, S=64): # poses: [B, 4, 4] # intrinsic: [3, 3] if not self.cuda_ray: return if isinstance(poses, np.ndarray): poses = torch.from_numpy(poses) B = poses.shape[0] fx, fy, cx, cy = intrinsic X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) count = torch.zeros_like(self.density_grid) poses = poses.to(count.device) # 5-level loop, forgive me... 
@torch.no_grad()
def _query_grid_sigmas(self, coords, cas):
    """Return scaled densities at jittered positions of grid cells ``coords``.

    Args:
        coords: ``[N, 3]`` integer cell coordinates in ``[0, grid_size)``.
        cas: cascade level whose extent is used to scale the positions.

    Returns:
        ``[N]`` tensor of ``sigma * density_scale``.
    """
    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1  # [N, 3] in [-1, 1]
    bound = min(2 ** cas, self.bound)
    half_grid_size = bound / self.grid_size
    # scale to current cascade's resolution
    cas_xyzs = xyzs * (bound - half_grid_size)
    # add noise in [-hgs, hgs]
    cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
    # query density
    sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
    sigmas *= self.density_scale
    return sigmas

@torch.no_grad()
def update_extra_state(self, decay=0.95, S=128):
    """Refresh the density grid / bitfield and the step-counter statistics.

    Call before each epoch. Does nothing unless ``cuda_ray`` is enabled.
    For the first 16 calls every cell is re-evaluated; afterwards only a
    random quarter of the cells plus an equally sized random draw from the
    currently occupied cells are re-evaluated.

    Args:
        decay: multiplicative decay applied to the old grid value before
            taking the maximum with the fresh estimate.
        S: chunk size along each axis for the full update.
    """
    if not self.cuda_ray:
        return

    ### update density grid
    device = self.density_bitfield.device
    tmp_grid = - torch.ones_like(self.density_grid)  # -1 marks "not re-evaluated"

    if self.iter_density < 16:
        # full update.
        X = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)
        Y = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)
        Z = torch.arange(self.grid_size, dtype=torch.int32, device=device).split(S)

        for xs in X:
            for ys in Y:
                for zs in Z:
                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]

                    # cascading
                    for cas in range(self.cascade):
                        tmp_grid[cas, indices] = self._query_grid_sigmas(coords, cas)
    else:
        # partial update (half the computation)
        # TODO: why no need of maxpool ?
        N = self.grid_size ** 3 // 4  # H * H * H / 4
        for cas in range(self.cascade):
            # random sample some positions
            coords = torch.randint(0, self.grid_size, (N, 3), device=device)  # [N, 3], in [0, 128)
            indices = raymarching.morton3D(coords).long()  # [N]

            # random sample occupied positions
            occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(-1)  # [Nz]
            # torch.randint(0, 0, ...) raises, so skip the occupied-cell draw
            # when the cascade currently has no occupied cell at all.
            if occ_indices.shape[0] > 0:
                rand_mask = torch.randint(0, occ_indices.shape[0], [N], dtype=torch.long, device=device)
                occ_indices = occ_indices[rand_mask]  # [Nz] --> [N], allow for duplication
                occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]
                # concat
                indices = torch.cat([indices, occ_indices], dim=0)
                coords = torch.cat([coords, occ_coords], dim=0)

            tmp_grid[cas, indices] = self._query_grid_sigmas(coords, cas)

    # (max-pooling tmp_grid for less aggressive culling was tried upstream:
    #  "No significant improvement...")

    # ema update: only cells that are trainable (>= 0) and freshly evaluated
    valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
    self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
    self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item()  # -1 regions are viewed as 0 density.
    self.iter_density += 1

    # convert to bitfield
    density_thresh = min(self.mean_density, self.density_thresh)
    self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield)

    ### update step counter
    total_step = min(16, self.local_step)
    if total_step > 0:
        self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
    self.local_step = 0
def forward(self, x, d):
    """Evaluate density and colour for a batch of points.

    Delegates to :meth:`density` and :meth:`color` instead of repeating
    their bodies, so the three entry points cannot drift apart.

    Args:
        x: ``[N, 3]`` positions in ``[-bound, bound]``.
        d: ``[N, 3]`` normalised directions in ``[-1, 1]``.

    Returns:
        ``(sigma, color)`` where ``sigma`` is the ``'sigma'`` entry of
        :meth:`density` and ``color`` is the sigmoid-activated rgb returned
        by :meth:`color` (no mask).
    """
    density_outputs = self.density(x)
    color = self.color(x, d, geo_feat=density_outputs['geo_feat'])
    return density_outputs['sigma'], color
def sample_pdf(bins, weights, n_samples, det=True):
    """Inverse-transform sampling from a piecewise-constant PDF.

    This implementation is from NeRF.

    Args:
        bins: ``[batch, T]`` bin edges.
        weights: ``[batch, T - 1]`` unnormalised bin weights.
        n_samples: samples to draw per row.
        det: evenly spaced quantiles if True, uniform random ones otherwise.

    Returns:
        ``[batch, n_samples]`` sampled positions.
    """
    dev, dt = bins.device, bins.dtype

    # Build the CDF (with a leading 0) from the regularised weights.
    w = weights + 1e-5  # prevent nans
    cdf = torch.cumsum(w / w.sum(-1, keepdim=True), -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
    n_edges = cdf.shape[-1]

    # Quantiles at which the CDF is inverted.
    u_shape = [*cdf.shape[:-1], n_samples]
    if not det:
        u = torch.rand(u_shape, dtype=dt, device=dev)
    else:
        half_step = 0.5 / n_samples
        u = torch.linspace(0. + half_step, 1. - half_step, steps=n_samples, dtype=dt, device=dev)
        u = u.expand(u_shape)
    u = u.contiguous()

    # Locate the CDF interval that brackets every quantile.
    hi = torch.searchsorted(cdf, u, right=True)
    lo = (hi - 1).clamp(min=0)
    hi = hi.clamp(max=n_edges - 1)
    bracket = torch.stack([lo, hi], -1)  # (batch, N_samples, 2)

    gather_shape = [bracket.shape[0], bracket.shape[1], n_edges]
    cdf_lo_hi = torch.gather(cdf.unsqueeze(1).expand(gather_shape), 2, bracket)
    bin_lo_hi = torch.gather(bins.unsqueeze(1).expand(gather_shape), 2, bracket)

    # Linear interpolation inside the interval; degenerate widths fall back to 1.
    width = cdf_lo_hi[..., 1] - cdf_lo_hi[..., 0]
    width = torch.where(width < 1e-5, torch.ones_like(width), width)
    frac = (u - cdf_lo_hi[..., 0]) / width
    return bin_lo_hi[..., 0] + frac * (bin_lo_hi[..., 1] - bin_lo_hi[..., 0])
class BaseRenderer(nn.Module):
    """Base class that renders a large ray batch in fixed-size chunks.

    Subclasses implement :meth:`render_impl` (render one chunk) and
    :meth:`render_with_loss`.

    Args:
        train_batch_num: chunk size used when ``is_train`` is true.
        test_batch_num: chunk size used otherwise.
    """

    def __init__(self, train_batch_num, test_batch_num):
        super().__init__()
        self.train_batch_num = train_batch_num
        self.test_batch_num = test_batch_num

    @abc.abstractmethod
    def render_impl(self, ray_batch, is_train, step):
        """Render one chunk; must return a dict of tensors."""

    @abc.abstractmethod
    def render_with_loss(self, ray_batch, is_train, step):
        """Render and compute the training losses."""

    def render(self, ray_batch, is_train, step):
        """Render ``ray_batch`` chunk by chunk and concatenate the results.

        Every entry of ``ray_batch`` is sliced along dim 0 with the same
        window; the per-chunk output dicts are concatenated key-wise on
        dim 0. The number of rays is taken from ``ray_batch['rays_o']``.
        """
        chunk = self.train_batch_num if is_train else self.test_batch_num
        total = ray_batch['rays_o'].shape[0]

        collected = {}
        for start in range(0, total, chunk):
            window = slice(start, start + chunk)
            piece = {name: tensor[window] for name, tensor in ray_batch.items()}
            for name, value in self.render_impl(piece, is_train, step).items():
                collected.setdefault(name, []).append(value)

        return {name: torch.cat(parts, 0) for name, parts in collected.items()}
def upsample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
    """
    Importance-sample new depths along each ray for a fixed ``inv_s``.

    Section-wise opacities are estimated from the SDF values at the existing
    samples (sigmoid of the estimated SDF at both ends of each section,
    scaled by ``inv_s``); the resulting weights are passed to ``sample_pdf``.

    @param rays_o: ray origins, indexed as [ray, 3]
    @param rays_d: ray directions, indexed as [ray, 3]
    @param z_vals: current depths, [batch_size, n_samples]
    @param sdf: SDF at the current samples (reshaped to [batch_size, n_samples])
    @param n_importance: number of new samples per ray
    @param inv_s: sharpness multiplier applied to the SDF before the sigmoid
    @return: detached tensor of new depths from ``sample_pdf``
    """
    device = rays_o.device
    batch_size, n_samples = z_vals.shape
    pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # n_rays, n_samples, 3
    inner_mask = self.get_inner_mask(pts)
    # radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
    # a section counts as inside when either of its two end points is inside
    inside_sphere = inner_mask[:, :-1] | inner_mask[:, 1:]
    sdf = sdf.reshape(batch_size, n_samples)
    prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
    prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
    mid_sdf = (prev_sdf + next_sdf) * 0.5
    # finite-difference slope of the SDF along the ray (1e-5 avoids division by zero)
    cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)

    # take the smaller of the current and the previous section's slope
    # (the first section has no predecessor, so it is compared against 0)
    prev_cos_val = torch.cat([torch.zeros([batch_size, 1], dtype=self.default_dtype, device=device), cos_val[:, :-1]], dim=-1)
    cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
    cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
    # keep only non-positive slopes, and zero them for sections outside the inner region
    cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere

    dist = (next_z_vals - prev_z_vals)
    # extrapolate the SDF from the section midpoint to both section ends
    prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
    next_esti_sdf = mid_sdf + cos_val * dist * 0.5
    prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
    next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
    alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
    # alpha-compositing weights: alpha times the accumulated transmittance
    weights = alpha * torch.cumprod(
        torch.cat([torch.ones([batch_size, 1], dtype=self.default_dtype, device=device), 1. - alpha + 1e-7], -1), -1)[:, :-1]

    z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
    return z_samples
def get_anneal_val(self, step):
    """Return the cosine-annealing ratio for training step ``step``.

    The ratio grows linearly as ``step / self.anneal_end`` and is capped at
    1.0. A non-positive ``anneal_end`` disables annealing (ratio 1.0); this
    includes 0, which previously raised ``ZeroDivisionError``.
    """
    if self.anneal_end <= 0:
        return 1.0
    return min(1.0, step / self.anneal_end)
def render_with_loss(self, ray_batch, is_train, step):
    """Render a ray batch and compute the training loss.

    Args:
        ray_batch: dict with at least ``'rgb'`` (ground-truth colours) and,
            when the mask loss is active, ``'mask'``.
        is_train: forwarded to ``self.render``.
        step: forwarded to ``self.render``.

    Returns:
        Tuple ``(loss, loss_batch)`` where ``loss`` is the weighted sum and
        ``loss_batch`` maps ``'eikonal'``, ``'rendering'`` and (optionally)
        ``'mask'`` to the individual terms.

    Raises:
        NotImplementedError: if ``self.rgb_loss`` is neither ``'soft_l1'``
            nor ``'mse'``.
    """
    rendered = self.render(ray_batch, is_train, step)
    target_rgb, predicted_rgb = ray_batch['rgb'], rendered['rgb']

    # Photometric term.
    if self.rgb_loss == 'soft_l1':
        squared_dist = (target_rgb - predicted_rgb).pow(2).sum(dim=-1)
        per_ray = torch.sqrt(squared_dist + 0.001)
    elif self.rgb_loss == 'mse':
        per_ray = F.mse_loss(predicted_rgb, target_rgb, reduction='none')
    else:
        raise NotImplementedError
    rgb_loss = per_ray.mean()

    # Eikonal term, averaged over samples inside the bounding volume.
    # NOTE(review): the 1e-5 is added to every mask element before summing,
    # not once to the total; kept as-is to preserve the training numerics.
    inside = rendered['inner_mask']
    eikonal_loss = (rendered['gradient_error'] * inside).sum() / (inside + 1e-5).sum()

    loss = self.lambda_rgb_loss * rgb_loss + self.lambda_eikonal_loss * eikonal_loss
    loss_batch = {'eikonal': eikonal_loss, 'rendering': rgb_loss}

    if self.lambda_mask_loss > 0 and self.use_mask:
        mask_loss = F.mse_loss(rendered['mask'], ray_batch['mask'], reduction='none').mean()
        loss = loss + self.lambda_mask_loss * mask_loss
        loss_batch['mask'] = mask_loss
    return loss, loss_batch
def _construct_ray_batch(self, images_info):
    """Build one ray (origin, direction, colour, mask) per pixel of every image.

    Args:
        images_info: dict of stacked tensors with keys ``'images'``
            (imn,h,w,3), ``'masks'`` (imn,h,w), ``'Ks'`` (imn,3,3) and
            ``'poses'`` (imn,3,4).

    Returns:
        dict with ``'rgb'`` (imn*h*w,3), ``'mask'`` (imn*h*w,1),
        ``'rays_o'`` (imn*h*w,3) and ``'rays_d'`` (imn*h*w,3, unit length).
    """
    image_num, h, w, _ = images_info['images'].shape

    # Pixel grid in (x, y) = (column, row) order. The indexing mode is passed
    # explicitly so the result does not depend on torch.meshgrid's deprecated
    # implicit default.
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    coords = torch.stack([xs, ys], -1)  # h,w,2
    coords = coords.float()[None, :, :, :].repeat(image_num, 1, 1, 1)  # imn,h,w,2
    coords = coords.reshape(image_num, h * w, 2)
    # Homogeneous pixel coordinates.
    coords = torch.cat([coords, torch.ones(image_num, h * w, 1, dtype=torch.float32)], 2)  # imn,h*w,3

    # imn,h*w,3 @ imn,3,3 => imn,h*w,3
    rays_d = coords @ torch.inverse(images_info['Ks']).permute(0, 2, 1)
    poses = images_info['poses']  # imn,3,4
    R, t = poses[:, :, :3], poses[:, :, 3:]
    rays_d = rays_d @ R
    rays_d = F.normalize(rays_d, dim=-1)
    rays_o = -R.permute(0, 2, 1) @ t  # imn,3,3 @ imn,3,1
    rays_o = rays_o.permute(0, 2, 1).repeat(1, h * w, 1)  # imn,h*w,3

    ray_batch = {
        'rgb': images_info['images'].reshape(image_num * h * w, 3),
        'mask': images_info['masks'].reshape(image_num * h * w, 1),
        'rays_o': rays_o.reshape(image_num * h * w, 3).float(),
        'rays_d': rays_d.reshape(image_num * h * w, 3).float(),
    }
    return ray_batch
def _shuffle_train_batch(self):
    """Rewind the training cursor and re-permute every tensor of the ray pool.

    One permutation is drawn and applied to all entries of
    ``self.train_batch`` so that the rows of the different keys stay aligned.
    """
    self.train_batch_i = 0
    permutation = torch.randperm(self.train_ray_num, device='cpu')
    # Entries are replaced in the existing dict (no keys are added or removed).
    for key in list(self.train_batch):
        self.train_batch[key] = self.train_batch[key][permutation]
def _slice_images_info(self, index):
    """Return a copy of ``self.images_info`` restricted to one image.

    Each value is sliced as ``value[index:index + 1]``, so the leading
    (image) axis is kept with length one.
    """
    window = slice(index, index + 1)
    single = {}
    for name, stacked in self.images_info.items():
        single[name] = stacked[window]
    return single
def extract_fields(bound_min, bound_max, resolution, query_func, batch_size=64, outside_val=1.0, device='cuda'):
    """Sample ``query_func`` on a regular ``resolution``^3 grid, chunk by chunk.

    Args:
        bound_min: lower corner of the grid, indexable as ``[0..2]``.
        bound_max: upper corner of the grid, indexable as ``[0..2]``.
        resolution: number of grid points per axis.
        query_func: callable mapping an (n,3) point tensor to n values.
        batch_size: number of grid points per axis in one chunk.
        outside_val: value written for points whose norm is >= 1.0.
        device: device the query points are moved to before calling
            ``query_func``. Defaults to ``'cuda'``, which matches the previous
            hard-coded ``.cuda()`` call.

    Returns:
        float32 numpy array of shape (resolution, resolution, resolution).
    """
    N = batch_size
    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    # 'ij' keeps the (x, y, z) axis order needed by the reshape below.
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
                    pts = torch.stack([xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)], dim=-1).to(device)
                    val = query_func(pts).detach()
                    # Points on or outside the unit sphere get the fill value.
                    outside_mask = torch.norm(pts, dim=-1) >= 1.0
                    val[outside_mask] = outside_val
                    val = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
    return u
'--seed', type=int, default=6033) parser.add_argument('-g', '--gpus', type=str, default='0,') parser.add_argument('-r', '--resume', action='store_true', default=False, dest='resume') parser.add_argument('--fp16', action='store_true', default=False, dest='fp16') opt = parser.parse_args() # seed_everything(opt.seed) # configs cfg = OmegaConf.load(opt.base) name = opt.name log_dir, ckpt_dir = Path(opt.log) / name, Path(opt.log) / name / 'ckpt' cfg.model.params['image_path'] = opt.image_path cfg.model.params['log_dir'] = log_dir cfg.model.params['data_path'] = opt.data_path # setup log_dir.mkdir(exist_ok=True, parents=True) ckpt_dir.mkdir(exist_ok=True, parents=True) trainer_config = cfg.trainer callback_config = cfg.callbacks model_config = cfg.model data_config = cfg.data data_config.params.seed = opt.seed data = instantiate_from_config(data_config) data.prepare_data() data.setup('fit') model = instantiate_from_config(model_config,) model.cpu() model.learning_rate = model_config.base_lr # logger logger = TensorBoardLogger(save_dir=log_dir, name='tensorboard_logs') callbacks=[] callbacks.append(LearningRateMonitor(logging_interval='step')) callbacks.append(ModelCheckpoint(dirpath=ckpt_dir, filename="{epoch:06}", verbose=True, save_last=True, every_n_train_steps=callback_config.save_interval)) # trainer trainer_config.update({ "accelerator": "cuda", "check_val_every_n_epoch": None, "benchmark": True, "num_sanity_val_steps": 0, "devices": 1, "gpus": opt.gpus, }) if opt.fp16: trainer_config['precision']=16 if opt.resume: callbacks.append(ResumeCallBacks()) trainer_config['resume_from_checkpoint'] = str(ckpt_dir / 'last.ckpt') else: if (ckpt_dir / 'last.ckpt').exists(): raise RuntimeError(f"checkpoint {ckpt_dir / 'last.ckpt'} existing ...") trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config, logger=logger, callbacks=callbacks) trainer.fit(model, data) model = model.cuda().eval() # render_images(model, log_dir) extract_mesh(model, log_dir) if 
def instantiate_from_config(config):
    """Build the object described by ``config``.

    ``config["target"]`` is a dotted import path to a callable and
    ``config["params"]`` (optional) holds its keyword arguments. The sentinel
    strings ``'__is_first_stage__'`` and ``"__is_unconditional__"`` yield
    ``None``.

    Raises:
        KeyError: if ``config`` has no ``target`` entry and is not one of the
            sentinel strings.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object.

    When ``reload`` is true the module is re-imported via
    ``importlib.reload`` before the attribute is looked up.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    resolved_module = importlib.import_module(module_name, package=None)
    return getattr(resolved_module, attr_name)
def cartesian_to_spherical(xyz):
    """Convert Cartesian points to (polar angle, azimuth, radius).

    Args:
        xyz: array indexed as (point, axis) with three columns x, y, z.

    Returns:
        ``np.array([theta, azimuth, z])`` where ``theta`` is the angle measured
        from the +Z axis down, ``azimuth`` is ``arctan2(y, x)`` and ``z`` is
        the Euclidean radius; each entry holds one value per input point.
    """
    # The original also allocated an unused ``ptsnew`` buffer; it is dropped here.
    xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
    z = np.sqrt(xy + xyz[:, 2] ** 2)
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # elevation angle defined from Z-axis down
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, z])
validation=False, T_in=1, T_out=1, fix_sample=False, ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) """ self.root_dir = Path(root_dir) self.total_view = total_view self.T_in = T_in self.T_out = T_out self.fix_sample = fix_sample self.paths = [] # # include all folders # for folder in os.listdir(self.root_dir): # if os.path.isdir(os.path.join(self.root_dir, folder)): # self.paths.append(folder) # load ids from .npy so we have exactly the same ids/order self.paths = np.load("../scripts/obj_ids.npy") # # only use 100K objects for ablation study # self.paths = self.paths[:100000] total_objects = len(self.paths) assert total_objects == 790152, 'total objects %d' % total_objects if validation: self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation else: self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training print('============= length of dataset %d =============' % len(self.paths)) self.tform = image_transforms downscale = 512 / 256. self.fx = 560. / downscale self.fy = 560. 
def get_T(self, target_RT, cond_RT):
    """Relative camera pose between a conditioning view and a target view.

    Both camera centres are computed as ``-R.T @ T`` and converted with
    ``self.cartesian_to_spherical``.

    Returns:
        torch tensor ``[d_theta, sin(d_azimuth), cos(d_azimuth), d_z]`` where
        the deltas are target minus condition and ``d_azimuth`` is wrapped
        into ``[0, 2*pi)``.
    """
    centers = []
    for RT in (target_RT, cond_RT):
        R, T = RT[:3, :3], RT[:, -1]
        centers.append(-R.T @ T)
    T_target, T_cond = centers

    theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
    theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])

    two_pi = 2 * math.pi
    d_azimuth = ((azimuth_target - azimuth_cond) % two_pi).item()
    components = [
        (theta_target - theta_cond).item(),
        math.sin(d_azimuth),
        math.cos(d_azimuth),
        (z_target - z_cond).item(),
    ]
    return torch.tensor(components)
def __getitem__(self, index):
    """Load the input/target views and their camera poses for one object.

    View selection:
      * ``fix_sample`` and ``T_out > 1``: targets are the first two views
        plus the last ``T_out - 2`` views; inputs are views ``1..T_in``.
      * ``fix_sample`` and ``T_out <= 1``: targets are the first ``T_out``
        views; inputs start at view ``T_out - 1`` (one overlapping view).
      * otherwise: ``T_in + T_out`` views are drawn at random with
        replacement.

    If anything fails while loading the requested object, a known-good
    object is loaded instead.

    Returns:
        dict with stacked tensors under ``'image_input'``, ``'image_target'``,
        ``'pose_out'`` and ``'pose_in'``.
    """
    total_view = 12
    if self.fix_sample:
        indexes = range(total_view)
        if self.T_out > 1:
            # ``indexes[-0:]`` is the whole range, so the tail must be empty
            # explicitly when T_out == 2 (otherwise 14 targets are returned).
            n_tail = self.T_out - 2
            tail = list(indexes[-n_tail:]) if n_tail > 0 else []
            index_targets = list(indexes[:2]) + tail
            index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
        else:
            index_targets = indexes[:self.T_out]
            index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
    else:
        assert self.T_in + self.T_out <= total_view
        # training with replace, including identity
        indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
        index_inputs = indexes[:self.T_in]
        index_targets = indexes[self.T_in:]

    color = [1., 1., 1., 1.]

    def load_views(folder, view_indices):
        # Read the image and the pose of every requested view of one object.
        images, poses = [], []
        for view in view_indices:
            image = self.load_im(os.path.join(folder, '%03d.png' % view), color)
            images.append(self.process_im(image))
            camera_RT = np.load(os.path.join(folder, '%03d.npy' % view))
            poses.append(self.get_pose(camera_RT))
        return images, poses

    filename = os.path.join(self.root_dir, self.paths[index])
    try:
        input_ims, cond_Ts = load_views(filename, index_inputs)
        target_ims, target_Ts = load_views(filename, index_targets)
    except (Exception, SystemExit):
        # SystemExit is caught on purpose: load_im calls sys.exit() on a read
        # failure and the fallback below relies on intercepting it.
        # KeyboardInterrupt is no longer swallowed.
        print('error loading data ', filename)
        filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
        input_ims, cond_Ts = load_views(filename, index_inputs)
        target_ims, target_Ts = load_views(filename, index_targets)

    # stack to batch
    data = {
        'image_input': torch.stack(input_ims, dim=0),
        'image_target': torch.stack(target_ims, dim=0),
        'pose_out': torch.stack(target_Ts, dim=0),
        'pose_in': torch.stack(cond_Ts, dim=0),
    }
    return data
) from .pipelines import ( AudioPipelineOutput, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, DiffusionPipeline, DiTPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, ) from .schedulers import ( CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, DDPMParallelScheduler, DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, PNDMScheduler, RePaintScheduler, SchedulerMixin, ScoreSdeVeScheduler, UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, ) from .training_utils import EMAModel try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .schedulers import LMSDiscreteScheduler try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .schedulers import DPMSolverSDEScheduler try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AudioLDMPipeline, CycleDiffusionPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, ImageTextPipelineOutput, KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, 
KandinskyPipeline, KandinskyPriorPipeline, KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22ControlnetPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, StableDiffusionModelEditingPipeline, StableDiffusionPanoramaPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, TextToVideoZeroPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, VQDiffusionPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise 
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: from .pipelines import StableDiffusionKDiffusionPipeline try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 else: from .pipelines import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, ) try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_librosa_objects import * # noqa F403 else: from .pipelines import AudioDiffusionPipeline, Mel try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 else: from .pipelines import SpectrogramDiffusionPipeline try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxDPMSolverMultistepScheduler, FlaxKarrasVeScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxSchedulerMixin, FlaxScoreSdeVeScheduler, ) try: if not (is_flax_available() and 
class BaseDiffusersCLICommand(ABC):
    """Interface implemented by every ``diffusers-cli`` sub-command.

    Both methods are abstract, so the class cannot be instantiated directly; a concrete command
    has to override `register_subcommand` and `run`.
    """

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's sub-parser to `parser`."""
        raise NotImplementedError
def main():
    """Entry point of the ``diffusers-cli`` command-line tool.

    Builds the top-level parser, lets each command register its own sub-parser and dispatches to the
    command selected on the command line. A sub-parser is expected to store a factory under ``func``;
    the factory is called with the parsed arguments and must return an object exposing ``run()``.

    Raises:
        SystemExit: with status 1, after printing the help text, when no sub-command was selected.
    """
    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        # The builtin `exit` is only injected by the `site` module and may be missing (e.g. `python -S`
        # or frozen executables); raising SystemExit directly has the same effect everywhere.
        raise SystemExit(1)

    # Run
    service = args.func(args)
    service.run()


if __name__ == "__main__":
    main()
def info_command_factory(_):
    """Create the command object for ``diffusers-cli env``; the parsed arguments are not used."""
    return EnvironmentCommand()


class EnvironmentCommand(BaseDiffusersCLICommand):
    """``env`` sub-command: collect and print library/platform versions for a bug report."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Add the ``env`` sub-parser to `parser` and bind it to `info_command_factory`."""
        parser.add_parser("env").set_defaults(func=info_command_factory)

    @staticmethod
    def format_dict(d):
        """Render `d` as one ``- key: value`` line per entry, terminated by a newline."""
        bullet_lines = [f"- {prop}: {val}" for prop, val in d.items()]
        return "\n".join(bullet_lines) + "\n"

    def run(self):
        """Print the environment report and return it as a dictionary.

        Optional libraries are imported only when the corresponding ``is_*_available()`` helper
        returns true; otherwise their version is reported as ``"not installed"``.
        """
        not_installed = "not installed"

        torch_version, cuda_status = not_installed, "NA"
        if is_torch_available():
            import torch

            torch_version = torch.__version__
            cuda_status = torch.cuda.is_available()

        tf_version = not_installed
        if is_transformers_available():
            import transformers

            tf_version = transformers.__version__

        acc_version = not_installed
        if is_accelerate_available():
            import accelerate

            acc_version = accelerate.__version__

        xf_version = not_installed
        if is_xformers_available():
            import xformers

            xf_version = xformers.__version__

        report = {
            "`diffusers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": f"{torch_version} ({cuda_status})",
            "Huggingface_hub version": huggingface_hub.__version__,
            "Transformers version": tf_version,
            "Accelerate version": acc_version,
            "xFormers version": xf_version,
            # The last two entries are left blank on purpose: the user fills them in by hand.
            "Using GPU in script?": "",
            "Using distributed or parallel set-up in script?": "",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(report))

        return report
class FrozenDict(OrderedDict):
    """Ordered dictionary whose entries are also readable as attributes.

    `__delitem__`, `setdefault`, `pop` and `update` always raise. `__setattr__` and `__setitem__`
    contain a guard meant to reject writes once construction has finished.

    NOTE(review): that guard calls `hasattr(self, "__frozen")` with the literal, un-mangled name,
    while the flag assigned in `__init__` is stored under the name-mangled attribute
    `_FrozenDict__frozen`. Unless an attribute literally called `__frozen` exists, the guard is
    therefore never triggered and item/attribute assignment keeps working after construction.
    Other code may rely on that, so the behavior is kept as is here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Expose every entry as an instance attribute too (`d.key` next to `d["key"]`).
        for entry_name in self:
            setattr(self, entry_name, self[entry_name])

        self.__frozen = True

    def __setattr__(self, name, value):
        locked = hasattr(self, "__frozen") and self.__frozen
        if locked:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        locked = hasattr(self, "__frozen") and self.__frozen
        if locked:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)

    def __delitem__(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``__delitem__`` on a {owner} instance.")

    def pop(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``pop`` on a {owner} instance.")

    def setdefault(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``setdefault`` on a {owner} instance.")

    def update(self, *args, **kwargs):
        owner = self.__class__.__name__
        raise Exception(f"You cannot use ``update`` on a {owner} instance.")
kwargs.pop("kwargs", None) if not hasattr(self, "_internal_dict"): internal_dict = kwargs else: previous_dict = dict(self._internal_dict) internal_dict = {**self._internal_dict, **kwargs} logger.debug(f"Updating config from {previous_dict} to {internal_dict}") self._internal_dict = FrozenDict(internal_dict) def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module """ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) is_attribute = name in self.__dict__ if is_in_config and not is_attribute: deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False) return self._internal_dict[name] raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the [`~ConfigMixin.from_config`] class method. Args: save_directory (`str` or `os.PathLike`): Directory where the configuration JSON file is saved (will be created if it does not exist). 
""" if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") os.makedirs(save_directory, exist_ok=True) # If we save using the predefined names, we can load using `from_config` output_config_file = os.path.join(save_directory, self.config_name) self.to_json_file(output_config_file) logger.info(f"Configuration saved in {output_config_file}") @classmethod def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): r""" Instantiate a Python class from a config dictionary. Parameters: config (`Dict[str, Any]`): A config dictionary from which the Python class is instantiated. Make sure to only load configuration files of compatible classes. return_unused_kwargs (`bool`, *optional*, defaults to `False`): Whether kwargs that are not consumed by the Python class should be returned or not. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it is loaded) and initiate the Python class. `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually overwrite the same named arguments in `config`. Returns: [`ModelMixin`] or [`SchedulerMixin`]: A model or scheduler object instantiated from a config dictionary. Examples: ```python >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler >>> # Download scheduler from huggingface.co and cache. 
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") >>> # Instantiate DDIM scheduler class with same config as DDPM >>> scheduler = DDIMScheduler.from_config(scheduler.config) >>> # Instantiate PNDM scheduler class with same config as DDPM >>> scheduler = PNDMScheduler.from_config(scheduler.config) ``` """ # <===== TO BE REMOVED WITH DEPRECATION # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated if "pretrained_model_name_or_path" in kwargs: config = kwargs.pop("pretrained_model_name_or_path") if config is None: raise ValueError("Please make sure to provide a config as the first positional argument.") # ======> if not isinstance(config, dict): deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." if "Scheduler" in cls.__name__: deprecation_message += ( f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" " be removed in v1.0.0." ) elif "Model" in cls.__name__: deprecation_message += ( f"If you were trying to load a model, please use {cls}.load_config(...) followed by" f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" " instead. This functionality will be removed in v1.0.0." 
) deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) # Allow dtype to be specified on initialization if "dtype" in unused_kwargs: init_dict["dtype"] = unused_kwargs.pop("dtype") # add possible deprecated kwargs for deprecated_kwarg in cls._deprecated_kwargs: if deprecated_kwarg in unused_kwargs: init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) # make sure to also save config parameters that might be used for compatible classes model.register_to_config(**hidden_dict) # add hidden kwargs of compatible classes to unused_kwargs unused_kwargs = {**unused_kwargs, **hidden_dict} if return_unused_kwargs: return (model, unused_kwargs) else: return model @classmethod def get_config_dict(cls, *args, **kwargs): deprecation_message = ( f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" " removed in version v1.0.0" ) deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) return cls.load_config(*args, **kwargs) @classmethod def load_config( cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, return_commit_hash=False, **kwargs, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: r""" Load a model or scheduler configuration. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with [`~ConfigMixin.save_config`]. 
cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. return_unused_kwargs (`bool`, *optional*, defaults to `False): Whether unused keyword arguments of the config are returned. return_commit_hash (`bool`, *optional*, defaults to `False): Whether the `commit_hash` of the loaded configuration are returned. 
Returns: `dict`: A dictionary of all the parameters stored in a JSON configuration file. """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) _ = kwargs.pop("mirror", None) subfolder = kwargs.pop("subfolder", None) user_agent = kwargs.pop("user_agent", {}) user_agent = {**user_agent, "file_type": "config"} user_agent = http_user_agent(user_agent) pretrained_model_name_or_path = str(pretrained_model_name_or_path) if cls.config_name is None: raise ValueError( "`self.config_name` is not defined. Note that one should not load a config from " "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" ) if os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): # Load from a PyTorch checkpoint config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) elif subfolder is not None and os.path.isfile( os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) ): config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name) else: raise EnvironmentError( f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." 
) else: try: # Load from URL or cache if already cached config_file = hf_hub_download( pretrained_model_name_or_path, filename=cls.config_name, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, subfolder=subfolder, revision=revision, ) except RepositoryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier" " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a" " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli" " login`." ) except RevisionNotFoundError: raise EnvironmentError( f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for" " this model name. Check the model page at" f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." ) except EntryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}." ) except HTTPError as err: raise EnvironmentError( "There was a specific connection error when trying to load" f" {pretrained_model_name_or_path}:\n{err}" ) except ValueError: raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to" " run the library in offline mode at" " 'https://huggingface.co/docs/diffusers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( f"Can't load config for '{pretrained_model_name_or_path}'. 
    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """Split a config dictionary into `__init__` arguments, unused entries and hidden entries.

        Args:
            config_dict: Mapping of config entries. It may contain the bookkeeping keys
                `_use_default_values` (names to drop before any further processing) and `_class_name`.
            **kwargs: Overrides. A keyword whose name is an expected `__init__` parameter is consumed and
                takes precedence over the value stored in `config_dict`.

        Returns:
            A tuple `(init_dict, unused_kwargs, hidden_config_dict)`:

            - `init_dict`: values for the expected `__init__` parameters of `cls`.
            - `unused_kwargs`: the config entries that were not consumed, merged with the keyword
              arguments that were not consumed.
            - `hidden_config_dict`: entries of the config (after dropping `_use_default_values` names) that
              did not end up in `init_dict`.
        """
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        # Collect the __init__ parameter names that only compatible classes accept and drop them from the
        # config; they are still present in `original_dict` and therefore end up in the hidden dict.
        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict
def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable.

    Every parameter of the wrapped init (except `self` and the ignored ones) is registered: with the value passed
    positionally or by keyword if there is one, otherwise with the default from the signature. The names that fell
    back to their signature default are recorded under `_use_default_values`.

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the
    init! They are still forwarded to `register_to_config`.

    Raises:
        RuntimeError: if the decorated init belongs to a class that does not inherit from `ConfigMixin`.
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])

        # All parameters of the wrapped init except `self`, in declaration order.
        init_parameters = list(inspect.signature(init).parameters.items())[1:]

        # Get positional arguments aligned with kwargs. The alignment uses the complete parameter list so
        # that an ignored parameter still consumes its positional slot instead of shifting every following
        # argument onto the wrong name.
        new_kwargs = {name: arg for arg, (name, _) in zip(args, init_parameters) if name not in ignore}
        positional_names = set(new_kwargs)

        # Then add all kwargs, falling back to the default declared in the signature
        for name, parameter in init_parameters:
            if name not in ignore and name not in new_kwargs:
                new_kwargs[name] = init_kwargs.get(name, parameter.default)

        # Take note of the parameters that were not present in the loaded config. Values passed positionally
        # were given explicitly, so they must not be reported as defaults.
        used_defaults = set(new_kwargs) - set(init_kwargs) - positional_names
        if used_defaults:
            new_kwargs["_use_default_values"] = list(used_defaults)

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
Retrieve all passed attributes init_kwargs = dict(kwargs.items()) # Retrieve default values fields = dataclasses.fields(self) default_kwargs = {} for field in fields: # ignore flax specific attributes if field.name in self._flax_internal_args: continue if type(field.default) == dataclasses._MISSING_TYPE: default_kwargs[field.name] = None else: default_kwargs[field.name] = getattr(self, field.name) # Make sure init_kwargs override default kwargs new_kwargs = {**default_kwargs, **init_kwargs} # dtype should be part of `init_kwargs`, but not `new_kwargs` if "dtype" in new_kwargs: new_kwargs.pop("dtype") # Get positional arguments aligned with kwargs for i, arg in enumerate(args): name = fields[i].name new_kwargs[name] = arg # Take note of the parameters that were not present in the loaded config if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0: new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs)) getattr(self, "register_to_config")(**new_kwargs) original_init(self, *args, **kwargs) cls.__init__ = init return cls ================================================ FILE: 4DoF/diffusers/dependency_versions_check.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import sys from .dependency_versions_table import deps from .utils.versions import require_version, require_version_core # define which module versions we always want to check at run time # (usually the ones defined in `install_requires` in setup.py) # # order specific notes: # - tqdm must be checked before tokenizers pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() if sys.version_info < (3, 7): pkgs_to_check_at_runtime.append("dataclasses") if sys.version_info < (3, 8): pkgs_to_check_at_runtime.append("importlib_metadata") for pkg in pkgs_to_check_at_runtime: if pkg in deps: if pkg == "tokenizers": # must be loaded here, or else tqdm check may fail from .utils import is_tokenizers_available if not is_tokenizers_available(): continue # not required, check version only if installed require_version_core(deps[pkg]) else: raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") def dep_version_check(pkg, hint=None): require_version(deps[pkg], hint) ================================================ FILE: 4DoF/diffusers/dependency_versions_table.py ================================================ # THIS FILE HAS BEEN AUTOGENERATED. To update: # 1. modify the `_deps` dict in setup.py # 2. 
run `make deps_table_update`` deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", "compel": "compel==0.1.8", "black": "black~=23.1", "datasets": "datasets", "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", "huggingface-hub": "huggingface-hub>=0.13.2", "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark", "isort": "isort>=5.5.4", "jax": "jax>=0.2.8,!=0.3.2", "jaxlib": "jaxlib>=0.1.65", "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", "torchsde": "torchsde", "note_seq": "note_seq", "librosa": "librosa", "numpy": "numpy", "omegaconf": "omegaconf", "parameterized": "parameterized", "protobuf": "protobuf>=3.20.3,<4", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", "ruff": "ruff>=0.0.241", "safetensors": "safetensors", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "scipy": "scipy", "onnx": "onnx", "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", "torch": "torch>=1.4", "torchvision": "torchvision", "transformers": "transformers>=4.25.1", "urllib3": "urllib3<=2.0.0", } ================================================ FILE: 4DoF/diffusers/experimental/__init__.py ================================================ from .rl import ValueGuidedRLPipeline ================================================ FILE: 4DoF/diffusers/experimental/rl/__init__.py ================================================ from .value_guided_sampling import ValueGuidedRLPipeline ================================================ FILE: 4DoF/diffusers/experimental/rl/value_guided_sampling.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import torch import tqdm from ...models.unet_1d import UNet1DModel from ...pipelines import DiffusionPipeline from ...utils import randn_tensor from ...utils.dummy_pt_objects import DDPMScheduler class ValueGuidedRLPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Pipeline for sampling actions from a diffusion model trained to predict sequences of states. Original implementation inspired by this repository: https://github.com/jannerm/diffuser. Parameters: value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward. unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this application is [`DDPMScheduler`]. env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models. 
""" def __init__( self, value_function: UNet1DModel, unet: UNet1DModel, scheduler: DDPMScheduler, env, ): super().__init__() self.value_function = value_function self.unet = unet self.scheduler = scheduler self.env = env self.data = env.get_dataset() self.means = {} for key in self.data.keys(): try: self.means[key] = self.data[key].mean() except: # noqa: E722 pass self.stds = {} for key in self.data.keys(): try: self.stds[key] = self.data[key].std() except: # noqa: E722 pass self.state_dim = env.observation_space.shape[0] self.action_dim = env.action_space.shape[0] def normalize(self, x_in, key): return (x_in - self.means[key]) / self.stds[key] def de_normalize(self, x_in, key): return x_in * self.stds[key] + self.means[key] def to_torch(self, x_in): if type(x_in) is dict: return {k: self.to_torch(v) for k, v in x_in.items()} elif torch.is_tensor(x_in): return x_in.to(self.unet.device) return torch.tensor(x_in, device=self.unet.device) def reset_x0(self, x_in, cond, act_dim): for key, val in cond.items(): x_in[:, key, act_dim:] = val.clone() return x_in def run_diffusion(self, x, conditions, n_guide_steps, scale): batch_size = x.shape[0] y = None for i in tqdm.tqdm(self.scheduler.timesteps): # create batch of timesteps to pass into model timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) for _ in range(n_guide_steps): with torch.enable_grad(): x.requires_grad_() # permute to match dimension for pre-trained models y = self.value_function(x.permute(0, 2, 1), timesteps).sample grad = torch.autograd.grad([y.sum()], [x])[0] posterior_variance = self.scheduler._get_variance(i) model_std = torch.exp(0.5 * posterior_variance) grad = model_std * grad grad[timesteps < 2] = 0 x = x.detach() x = x + scale * grad x = self.reset_x0(x, conditions, self.action_dim) prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) # TODO: verify deprecation of this kwarg x = self.scheduler.step(prev_x, i, x, 
predict_epsilon=False)["prev_sample"] # apply conditions to the trajectory (set the initial state) x = self.reset_x0(x, conditions, self.action_dim) x = self.to_torch(x) return x, y def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): # normalize the observations and create batch dimension obs = self.normalize(obs, "observations") obs = obs[None].repeat(batch_size, axis=0) conditions = {0: self.to_torch(obs)} shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) # generate initial noise and apply our conditions (to make the trajectories start at current state) x1 = randn_tensor(shape, device=self.unet.device) x = self.reset_x0(x1, conditions, self.action_dim) x = self.to_torch(x) # run the diffusion process x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) # sort output trajectories by value sorted_idx = y.argsort(0, descending=True).squeeze() sorted_values = x[sorted_idx] actions = sorted_values[:, :, : self.action_dim] actions = actions.detach().cpu().numpy() denorm_actions = self.de_normalize(actions, key="actions") # select the action with the highest value if y is not None: selected_index = 0 else: # if we didn't run value guiding, select a random action selected_index = np.random.randint(0, batch_size) denorm_actions = denorm_actions[selected_index, 0] return denorm_actions ================================================ FILE: 4DoF/diffusers/image_processor.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import List, Optional, Union import numpy as np import PIL import torch from PIL import Image from .configuration_utils import ConfigMixin, register_to_config from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate class VaeImageProcessor(ConfigMixin): """ Image processor for VAE. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. do_convert_rgb (`bool`, *optional*, defaults to be `False`): Whether to convert the images to RGB format. """ config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, do_convert_rgb: bool = False, ): super().__init__() @staticmethod def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: """ Convert a numpy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] 
images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images @staticmethod def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: """ Convert a PIL image or a list of PIL images to NumPy arrays. """ if not isinstance(images, list): images = [images] images = [np.array(image).astype(np.float32) / 255.0 for image in images] images = np.stack(images, axis=0) return images @staticmethod def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: """ Convert a NumPy image to a PyTorch tensor. """ if images.ndim == 3: images = images[..., None] images = torch.from_numpy(images.transpose(0, 3, 1, 2)) return images @staticmethod def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: """ Convert a PyTorch tensor to a NumPy image. """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @staticmethod def normalize(images): """ Normalize an image array to [-1,1]. """ return 2.0 * images - 1.0 @staticmethod def denormalize(images): """ Denormalize an image array to [0,1]. """ return (images / 2 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: """ Converts an image to RGB format. """ image = image.convert("RGB") return image def resize( self, image: PIL.Image.Image, height: Optional[int] = None, width: Optional[int] = None, ) -> PIL.Image.Image: """ Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`. 
""" if height is None: height = image.height if width is None: width = image.width width, height = ( x - x % self.config.vae_scale_factor for x in (width, height) ) # resize to integer multiple of vae_scale_factor image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) return image def preprocess( self, image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], height: Optional[int] = None, width: Optional[int] = None, ) -> torch.Tensor: """ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors. """ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) if isinstance(image, supported_formats): image = [image] elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" ) if isinstance(image[0], PIL.Image.Image): if self.config.do_convert_rgb: image = [self.convert_to_rgb(i) for i in image] if self.config.do_resize: image = [self.resize(i, height, width) for i in image] image = self.pil_to_numpy(image) # to np image = self.numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) image = self.numpy_to_pt(image) _, _, height, width = image.shape if self.config.do_resize and ( height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 ): raise ValueError( f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}" f"currently the sizes are {height} and {width}. 
You can also pass a PIL image instead to use resize option in VAEImageProcessor" ) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) _, channel, height, width = image.shape # don't need any preprocess if the image is latents if channel == 4: return image if self.config.do_resize and ( height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 ): raise ValueError( f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}" f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" ) # expected range [0,1], normalize to [-1,1] do_normalize = self.config.do_normalize if image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", FutureWarning, ) do_normalize = False if do_normalize: image = self.normalize(image) return image def postprocess( self, image: torch.FloatTensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, ): if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = ( f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " "`pil`, `np`, `pt`, `latent`" ) deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" if output_type == "latent": return image if do_denormalize is None: do_denormalize = [self.config.do_normalize] * image.shape[0] image = torch.stack( [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] ) if output_type == "pt": return image image = self.pt_to_numpy(image) if output_type == "np": return image if output_type == "pil": return self.numpy_to_pil(image) class VaeImageProcessorLDM3D(VaeImageProcessor): """ Image processor for VAE LDM3D. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. """ config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, ): super().__init__() @staticmethod def numpy_to_pil(images): """ Convert a NumPy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] 
images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image[:, :, :3]) for image in images] return pil_images @staticmethod def rgblike_to_depthmap(image): """ Args: image: RGB-like depth image Returns: depth map """ return image[:, :, 1] * 2**8 + image[:, :, 2] def numpy_to_depth(self, images): """ Convert a NumPy depth image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images_depth = images[:, :, :, 3:] if images.shape[-1] == 6: images_depth = (images_depth * 255).round().astype("uint8") pil_images = [ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth ] elif images.shape[-1] == 4: images_depth = (images_depth * 65535.0).astype(np.uint16) pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth] else: raise Exception("Not supported") return pil_images def postprocess( self, image: torch.FloatTensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, ): if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = ( f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " "`pil`, `np`, `pt`, `latent`" ) deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" if do_denormalize is None: do_denormalize = [self.config.do_normalize] * image.shape[0] image = torch.stack( [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] ) image = self.pt_to_numpy(image) if output_type == "np": if image.shape[-1] == 6: image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) else: image_depth = image[:, :, :, 3:] return image[:, :, :, :3], image_depth if output_type == "pil": return self.numpy_to_pil(image), self.numpy_to_depth(image) else: raise Exception(f"This type {output_type} is not supported") ================================================ FILE: 4DoF/diffusers/loaders.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import warnings from collections import defaultdict from pathlib import Path from typing import Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from .models.attention_processor import ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, LoRAAttnAddedKVProcessor, LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, XFormersAttnProcessor, ) from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, is_transformers_available, logging, ) if is_safetensors_available(): import safetensors if is_transformers_available(): from transformers import PreTrainedModel, PreTrainedTokenizer logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" TEXT_INVERSION_NAME = "learned_embeds.bin" TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" class AttnProcsLayers(torch.nn.Module): def __init__(self, state_dict: Dict[str, torch.Tensor]): super().__init__() self.layers = torch.nn.ModuleList(state_dict.values()) self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} # .processor for unet, .self_attn for text encoder self.split_keys = [".processor", ".self_attn"] # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` def map_to(module, state_dict, *args, **kwargs): new_state_dict = {} for key, value in state_dict.items(): num = int(key.split(".")[1]) # 0 is always "layers" new_key = key.replace(f"layers.{num}", 
module.mapping[num]) new_state_dict[new_key] = value return new_state_dict def remap_key(key, state_dict): for k in self.split_keys: if k in key: return key.split(k)[0] + k raise ValueError( f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." ) def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) for key in all_keys: replace_key = remap_key(key, state_dict) new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") state_dict[new_key] = state_dict[key] del state_dict[key] self._register_state_dict_hook(map_to) self._register_load_state_dict_pre_hook(map_from, with_module=True) class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) and be a `torch.nn.Module` class. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either: - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a directory (for example `./my_model_directory`) containing the model weights saved with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. 
""" cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except IOError as e: if not allow_pickle: raise e # try loading non-safetensors weights pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, 
weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = torch.load(model_file, map_location="cpu") else: state_dict = pretrained_model_name_or_path_or_dict # fill attn processors attn_processors = {} is_lora = all("lora" in k for k in state_dict.keys()) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: is_new_lora_format = all( key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() ) if is_new_lora_format: # Strip the `"unet"` prefix. is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) if is_text_encoder_present: warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." 
warnings.warn(warn_message) unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in lora_grouped_dict.items(): rank = value_dict["to_k_lora.down.weight"].shape[0] hidden_size = value_dict["to_k_lora.up.weight"].shape[0] attn_processor = self for sub_key in key.split("."): attn_processor = getattr(attn_processor, sub_key) if isinstance( attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) ): cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] attn_processor_class = LoRAAttnAddedKVProcessor else: cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): attn_processor_class = LoRAXFormersAttnProcessor else: attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) attn_processors[key] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank, network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) elif is_custom_diffusion: custom_diffusion_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): if len(value) == 0: custom_diffusion_grouped_dict[key] = {} else: if "to_out" in key: attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) else: attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in custom_diffusion_grouped_dict.items(): if len(value_dict) == 0: 
def save_attn_procs(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = False,
    **kwargs,
):
    r"""
    Serialize the attention processors of this model into a single weight file so that they can be reloaded
    with [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`].

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Target directory. It is created when missing. If the path points to an existing file, an error is
            logged and nothing is written.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Accepted for API compatibility with distributed training scripts. The body of this method does not
            read it.
        weight_name (`str`, *optional*):
            File name to write inside `save_directory`. When omitted, a default name is picked from the
            processor family (LoRA or Custom Diffusion) and from `safe_serialization`.
        save_function (`Callable`, *optional*):
            Callable invoked as `save_function(state_dict, path)`. Defaults to a safetensors writer when
            `safe_serialization` is set and to `torch.save` otherwise.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Selects the safetensors writer and the `*_SAFE` default file names.
        kwargs:
            Only the deprecated `weights_name` entry is looked at, and only when `weight_name` is not given.
    """
    # The deprecated spelling is consulted solely as a fallback for a missing `weight_name`.
    if not weight_name:
        weight_name = deprecate(
            "weights_name",
            "0.20.0",
            "`weights_name` is deprecated, please use `weight_name` instead.",
            take_from=kwargs,
        )

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    def _write_safetensors(weights, filename):
        return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

    if save_function is None:
        save_function = _write_safetensors if safe_serialization else torch.save

    os.makedirs(save_directory, exist_ok=True)

    custom_diffusion_types = (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
    processors = self.attn_processors
    is_custom_diffusion = any(isinstance(proc, custom_diffusion_types) for proc in processors.values())

    if is_custom_diffusion:
        # Only the Custom Diffusion processors contribute weights ...
        custom_only = {name: proc for name, proc in processors.items() if isinstance(proc, custom_diffusion_types)}
        state_dict = AttnProcsLayers(custom_only).state_dict()
        # ... while every parameter-free processor is recorded as an empty placeholder entry.
        for name, proc in processors.items():
            if len(proc.state_dict()) == 0:
                state_dict[name] = {}
    else:
        state_dict = AttnProcsLayers(processors).state_dict()

    if weight_name is None:
        # Indexed by (safe_serialization, is_custom_diffusion).
        default_names = {
            (True, True): CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE,
            (True, False): LORA_WEIGHT_NAME_SAFE,
            (False, True): CUSTOM_DIFFUSION_WEIGHT_NAME,
            (False, False): LORA_WEIGHT_NAME,
        }
        weight_name = default_names[(bool(safe_serialization), is_custom_diffusion)]

    # Save the model
    output_path = os.path.join(save_directory, weight_name)
    save_function(state_dict, output_path)
    logger.info(f"Model weights saved in {output_path}")
If the prompt has no textual inversion token or if the textual inversion token is a single vector, the input prompt is returned. Parameters: prompt (`str` or list of `str`): The prompt or prompts to guide the image generation. tokenizer (`PreTrainedTokenizer`): The tokenizer responsible for encoding the prompt into input tokens. Returns: `str` or list of `str`: The converted prompt """ if not isinstance(prompt, List): prompts = [prompt] else: prompts = prompt prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] if not isinstance(prompt, List): return prompts[0] return prompts def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): r""" Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds to a multi-vector textual inversion embedding, this function will process the prompt so that the special token is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. Parameters: prompt (`str`): The prompt to guide the image generation. tokenizer (`PreTrainedTokenizer`): The tokenizer responsible for encoding the prompt into input tokens. 
def load_textual_inversion(
    self,
    pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
    token: Optional[Union[str, List[str]]] = None,
    **kwargs,
):
    r"""
    Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
    Automatic1111 formats are supported).

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
            Can be either one of the following or a list of them:

                - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
                  pretrained model hosted on the Hub.
                - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
                  inversion weights.
                - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        token (`str` or `List[str]`, *optional*):
            Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
            list, then `token` must also be a list of equal length.
        weight_name (`str`, *optional*):
            Name of a custom weight file. This should be used when:

                - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
                  name such as `text_inv.bin`.
                - The saved textual inversion file is in the Automatic1111 format.
        use_safetensors (`bool`, *optional*):
            When `True`, only safetensors weights are accepted and a failure to load them is raised. When left
            unset, safetensors weights are tried first if the library is available and any failure falls back to
            the regular PyTorch weight file.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
            incompletely downloaded files are deleted.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        mirror (`str`, *optional*):
            Not read by this method; a value passed here is ignored.

    Raises:
        `ValueError`:
            If the pipeline has no compatible `tokenizer` / `text_encoder`, if safetensors is requested but not
            installed, if the number of models and tokens differ, if tokens are duplicated, if a bare tensor is
            loaded without a `token`, if the loaded object is in neither the Diffusers nor the Automatic1111
            layout, or if the token is already part of the tokenizer vocabulary.

    Example:

    To load a textual inversion embedding vector in 🤗 Diffusers format:

    ```py
    from diffusers import StableDiffusionPipeline
    import torch

    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    pipe.load_textual_inversion("sd-concepts-library/cat-toy")

    prompt = "A <cat-toy> backpack"

    image = pipe(prompt, num_inference_steps=50).images[0]
    image.save("cat-backpack.png")
    ```

    To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
    (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
    locally:

    ```py
    from diffusers import StableDiffusionPipeline
    import torch

    model_id = "runwayml/stable-diffusion-v1-5"
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

    prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

    image = pipe(prompt, num_inference_steps=50).images[0]
    image.save("character.png")
    ```

    """
    if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
        raise ValueError(
            f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
            f" `{self.load_textual_inversion.__name__}`"
        )

    if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
        raise ValueError(
            f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
            f" `{self.load_textual_inversion.__name__}`"
        )

    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    if use_safetensors and not is_safetensors_available():
        raise ValueError(
            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
        )

    # Falling back to the pickle-based file is only allowed when the caller did not choose a format explicitly.
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = is_safetensors_available()
        allow_pickle = True

    user_agent = {
        "file_type": "text_inversion",
        "framework": "pytorch",
    }

    # Normalize both arguments to equally long lists so they can be walked in lockstep.
    if isinstance(pretrained_model_name_or_path, list):
        sources = pretrained_model_name_or_path
    else:
        sources = [pretrained_model_name_or_path]

    if isinstance(token, str):
        requested_tokens = [token]
    elif token is None:
        requested_tokens = [None] * len(sources)
    else:
        requested_tokens = token

    if len(sources) != len(requested_tokens):
        raise ValueError(
            f"You have passed a list of models of length {len(sources)}, and list of tokens of length"
            f" {len(requested_tokens)}. Make sure both lists have the same length."
        )

    valid_tokens = [t for t in requested_tokens if t is not None]
    if len(set(valid_tokens)) < len(valid_tokens):
        raise ValueError(f"You have passed a list of tokens that contains duplicates: {requested_tokens}")

    file_kwargs = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "resume_download": resume_download,
        "proxies": proxies,
        "local_files_only": local_files_only,
        "use_auth_token": use_auth_token,
        "revision": revision,
        "subfolder": subfolder,
        "user_agent": user_agent,
    }
    wants_safetensors = (use_safetensors and weight_name is None) or (
        weight_name is not None and weight_name.endswith(".safetensors")
    )

    token_ids_and_embeddings = []

    for source, current_token in zip(sources, requested_tokens):
        if not isinstance(source, dict):
            # 1. Load textual inversion file
            model_file = None
            # Let's first try to load .safetensors weights
            if wants_safetensors:
                try:
                    model_file = _get_model_file(
                        source, weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, **file_kwargs
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except Exception:
                    if not allow_pickle:
                        raise

                    model_file = None

            if model_file is None:
                model_file = _get_model_file(source, weights_name=weight_name or TEXT_INVERSION_NAME, **file_kwargs)
                # NOTE(security): `torch.load` unpickles the file, so this path must only be used with
                # embeddings from a trusted source.
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = source

        # 2. Load token and embedding correctly from file
        loaded_token = None
        if isinstance(state_dict, torch.Tensor):
            if current_token is None:
                raise ValueError(
                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                )
            embedding = state_dict
        elif len(state_dict) == 1:
            # diffusers
            loaded_token, embedding = next(iter(state_dict.items()))
        elif "string_to_param" in state_dict:
            # A1111
            loaded_token = state_dict["name"]
            embedding = state_dict["string_to_param"]["*"]
        else:
            # Without this branch `embedding` would be unbound and the method would fail further down with an
            # unrelated `UnboundLocalError`.
            raise ValueError(
                "The loaded textual inversion state dict is in an unsupported format: expected either a single"
                f" entry or a `string_to_param` key, but found the keys {list(state_dict.keys())}."
            )

        if current_token is not None and loaded_token != current_token:
            logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {current_token}.")
        else:
            current_token = loaded_token

        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)

        # 3. Make sure we don't mess up the tokenizer or text encoder
        vocab = self.tokenizer.get_vocab()
        if current_token in vocab:
            raise ValueError(
                f"Token {current_token} already in tokenizer vocabulary. Please choose a different token name or remove {current_token} and embedding from the tokenizer and text encoder."
            )
        elif f"{current_token}_1" in vocab:
            multi_vector_tokens = [current_token]
            i = 1
            while f"{current_token}_{i}" in self.tokenizer.added_tokens_encoder:
                multi_vector_tokens.append(f"{current_token}_{i}")
                i += 1

            raise ValueError(
                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
            )

        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1

        if is_multi_vector:
            # One token per row: `<tok>`, `<tok>_1`, `<tok>_2`, ...
            new_tokens = [current_token] + [f"{current_token}_{i}" for i in range(1, embedding.shape[0])]
            new_embeddings = list(embedding)
        else:
            new_tokens = [current_token]
            new_embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]

        # add tokens and get ids
        self.tokenizer.add_tokens(new_tokens)
        token_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
        token_ids_and_embeddings += zip(token_ids, new_embeddings)

        logger.info(f"Loaded textual inversion embedding for {current_token}.")

    # resize token embeddings and set all new embeddings
    self.text_encoder.resize_token_embeddings(len(self.tokenizer))
    embedding_matrix = self.text_encoder.get_input_embeddings().weight.data
    for token_id, new_embedding in token_ids_and_embeddings:
        embedding_matrix[token_id] = new_embedding
force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. 
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) # set lora scale to a reasonable default self._lora_scale = 1.0 if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except IOError as e: if not allow_pickle: raise e # try loading non-safetensors weights pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, 
use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = torch.load(model_file, map_location="cpu") else: state_dict = pretrained_model_name_or_path_or_dict # Convert kohya-ss Style LoRA attn procs to diffusers attn procs network_alpha = None if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys): # Load the layers corresponding to UNet. unet_keys = [k for k in keys if k.startswith(self.unet_name)] logger.info(f"Loading {self.unet_name}.") unet_lora_state_dict = { k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys } self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) # Load the layers corresponding to text encoder and make necessary adjustments. text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)] text_encoder_lora_state_dict = { k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {self.text_encoder_name}.") attn_procs_text_encoder = self._load_text_encoder_attn_procs( text_encoder_lora_state_dict, network_alpha=network_alpha ) self._modify_text_encoder(attn_procs_text_encoder) # save lora attn procs of text encoder so that it can be easily retrieved self._text_encoder_lora_attn_procs = attn_procs_text_encoder # Otherwise, we're dealing with the old format. 
def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
    r"""
    Wrap the `forward` of the text encoder's attention projections so that the output of the matching LoRA
    layer, scaled by `self.lora_scale`, is added to the original result.

    Parameters:
        attn_processors: Dict[str, `LoRAAttnProcessor`]:
            Maps the name of each attention module (as yielded by `text_encoder.named_modules()`) to the
            processor holding its LoRA layers.
    """

    def build_patched_forward(original_forward, lora_layer):
        # A dedicated scope per projection pins `original_forward` and `lora_layer`; a closure created directly
        # inside the loops would see only the values of the last iteration.
        # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
        def patched_forward(x):
            # `self.lora_scale` is read on every call, so changing the scale later takes effect immediately.
            return original_forward(x) + self.lora_scale * lora_layer(x)

        return patched_forward

    # Start from a clean slate: undo any patch left behind by a previous call.
    self._remove_text_encoder_monkey_patch()

    for module_name, attention in self.text_encoder.named_modules():
        # Only the attention modules of the text encoder are of interest.
        if not module_name.endswith(TEXT_ENCODER_ATTN_MODULE):
            continue

        attr_pairs = self._lora_attn_processor_attr_to_text_encoder_attr
        for lora_attr, projection_attr in attr_pairs.items():
            # q/k/v/out projection of the attention module and the LoRA layer that belongs to it
            projection = attention.get_submodule(projection_attr)
            lora_layer = attn_processors[module_name].get_submodule(lora_attr)

            # Keep the untouched forward on the module so the patch can be reverted later.
            original_forward = projection.forward
            projection.old_forward = original_forward

            # Monkey-patch.
            projection.forward = build_patched_forward(original_forward, lora_layer)
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., `./my_model_directory/`. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `diffusers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. mirror (`str`, *optional*): Mirror source to accelerate downloads in China. 
If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. Returns: `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding [`LoRAAttnProcessor`]. It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except IOError as e: if not allow_pickle: raise e # try loading non-safetensors weights pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = torch.load(model_file, map_location="cpu") else: state_dict = pretrained_model_name_or_path_or_dict # fill attn processors attn_processors = {} is_lora = all("lora" in k for k in state_dict.keys()) if is_lora: lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in lora_grouped_dict.items(): rank = value_dict["to_k_lora.down.weight"].shape[0] cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] 
hidden_size = value_dict["to_k_lora.up.weight"].shape[0] attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) attn_processors[key] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank, network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) else: raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.") # set correct dtype & device attn_processors = { k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items() } return attn_processors @classmethod def save_lora_weights( self, save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, safe_serialization: bool = False, ): r""" Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): Directory to save LoRA parameters to. Will be created if it doesn't exist. unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the UNet. text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text encoder LoRA state dict because it comes 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. 
def _convert_kohya_lora_to_diffusers(self, state_dict):
    """Rename Kohya-format LoRA weights to the diffusers attention-processor layout.

    Only keys containing ``lora_down`` drive the conversion. For every kept
    key, the matching ``<module>.lora_up.weight`` entry is read from
    ``state_dict`` and stored under the converted name with ``.down.``
    replaced by ``.up.``.

    Args:
        state_dict: Mapping from Kohya parameter names to weights. Optional
            ``<module>.alpha`` entries must support ``.item()``.

    Returns:
        A tuple ``(new_state_dict, network_alpha)``. ``new_state_dict`` holds
        the converted weights with keys prefixed by ``UNET_NAME`` or
        ``TEXT_ENCODER_NAME``; ``network_alpha`` is the alpha shared by all
        modules, or ``None`` when no alpha entry was seen.

    Raises:
        ValueError: If two modules carry different alpha values.
    """
    # Ordered (old, new) substitutions applied after every "_" became ".".
    unet_renames = (
        ("down.blocks", "down_blocks"),
        ("mid.block", "mid_block"),
        ("up.blocks", "up_blocks"),
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.k.lora", "to_k_lora"),
        ("to.v.lora", "to_v_lora"),
        ("to.out.0.lora", "to_out_lora"),
    )
    text_encoder_renames = (
        ("text.model", "text_model"),
        ("self.attn", "self_attn"),
        ("q.proj.lora", "to_q_lora"),
        ("k.proj.lora", "to_k_lora"),
        ("v.proj.lora", "to_v_lora"),
        ("out.proj.lora", "to_out_lora"),
    )

    converted_unet = {}
    converted_te = {}
    network_alpha = None

    for key, down_weight in state_dict.items():
        if "lora_down" not in key:
            continue

        lora_name = key.split(".")[0]
        up_key = lora_name + ".lora_up.weight"
        alpha_key = lora_name + ".alpha"

        # Every module that declares an alpha must agree on its value.
        if alpha_key in state_dict:
            alpha = state_dict[alpha_key].item()
            if network_alpha is None:
                network_alpha = alpha
            elif network_alpha != alpha:
                raise ValueError("Network alpha is not consistent")

        if lora_name.startswith("lora_unet_"):
            new_key = key.replace("lora_unet_", "").replace("_", ".")
            for old, new in unet_renames:
                new_key = new_key.replace(old, new)

            # Only attention layers inside transformer blocks are kept.
            is_attention = "attn1" in new_key or "attn2" in new_key
            if "transformer_blocks" in new_key and is_attention:
                new_key = new_key.replace("attn1", "attn1.processor")
                new_key = new_key.replace("attn2", "attn2.processor")
                converted_unet[new_key] = down_weight
                converted_unet[new_key.replace(".down.", ".up.")] = state_dict[up_key]
        elif lora_name.startswith("lora_te_"):
            new_key = key.replace("lora_te_", "").replace("_", ".")
            for old, new in text_encoder_renames:
                new_key = new_key.replace(old, new)

            # Only self-attention projections of the text encoder are kept.
            if "self_attn" in new_key:
                converted_te[new_key] = down_weight
                converted_te[new_key.replace(".down.", ".up.")] = state_dict[up_key]

    prefixed_unet = {f"{UNET_NAME}.{name}": weight for name, weight in converted_unet.items()}
    prefixed_te = {f"{TEXT_ENCODER_NAME}.{name}": weight for name, weight in converted_te.items()}
    return {**prefixed_unet, **prefixed_te}, network_alpha
class FromSingleFileMixin:
    """
    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
    """

    @classmethod
    def from_ckpt(cls, *args, **kwargs):
        """Deprecated alias of `from_single_file`; emits a deprecation notice and forwards all arguments."""
        deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
        deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
        return cls.from_single_file(*args, **kwargs)

    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
        is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`):
                Can be either:
                    - A link to the `.ckpt` file (for example
                      `"https://huggingface.co/{repo_id}/blob/main/{path_to_file}.ckpt"`) on the Hub.
                    - A path to a *file* containing all pipeline weights.
                Path-like objects are converted to `str` before processing.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub. Defaults to the value of `HF_HUB_OFFLINE`.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git. Passed through to `hf_hub_download`.
            use_safetensors (`bool`, *optional*):
                Defaults to `None` when the safetensors library is installed and to `False` otherwise. Whether the
                checkpoint is read as safetensors is decided by the `.safetensors` file extension; pointing at such a
                file while this value is `False` raises a `ValueError`.
            extract_ema (`bool`, *optional*, defaults to `False`):
                Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
                higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
            upcast_attention (`bool`, *optional*, defaults to `None`):
                Whether the attention computation should always be upcasted.
            image_size (`int`, *optional*, defaults to `None`):
                The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
                Diffusion v2 base model. Use 768 for Stable Diffusion v2.
            prediction_type (`str`, *optional*):
                The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
                the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
            num_in_channels (`int`, *optional*, defaults to `None`):
                The number of input channels. If `None`, it will be automatically inferred.
            scheduler_type (`str`, *optional*, defaults to `"pndm"`):
                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
                "ddim"]`.
            load_safety_checker (`bool`, *optional*, defaults to `True`):
                Whether to load the safety checker or not.
            text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
                An instance of
                [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
                specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
                variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
                needed.
            tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
                An instance of
                [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
                to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
                itself, if needed.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Any keyword argument not listed above is accepted but is not used by this method.

        Raises:
            ValueError: If a `.safetensors` file is requested while `use_safetensors` is `False`, or if the calling
                pipeline class is not one of the supported pipelines.

        Examples:

        ```py
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Load pipeline from a local file
        >>> # file is stored under ./v1-5-pruned-emaonly.ckpt
        >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")

        >>> # Enable float16 and move to GPU
        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipeline.to("cuda")
        ```
        """
        # import here to avoid circular dependency
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        extract_ema = kwargs.pop("extract_ema", False)
        image_size = kwargs.pop("image_size", None)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
        load_safety_checker = kwargs.pop("load_safety_checker", True)
        prediction_type = kwargs.pop("prediction_type", None)
        text_encoder = kwargs.pop("text_encoder", None)
        tokenizer = kwargs.pop("tokenizer", None)

        torch_dtype = kwargs.pop("torch_dtype", None)

        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        pipeline_name = cls.__name__

        # `os.PathLike` inputs are documented as supported, but everything below uses `str` methods
        # (`rsplit`, `startswith`); normalise once so a `pathlib.Path` does not raise AttributeError.
        pretrained_model_link_or_path = str(pretrained_model_link_or_path)

        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
        from_safetensors = file_extension == "safetensors"

        if from_safetensors and use_safetensors is False:
            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

        # TODO: For now we only support stable diffusion
        stable_unclip = None
        model_type = None
        controlnet = False

        if pipeline_name == "StableDiffusionControlNetPipeline":
            # Model type will be inferred from the checkpoint.
            controlnet = True
        elif "StableDiffusion" in pipeline_name:
            # Model type will be inferred from the checkpoint.
            pass
        elif pipeline_name == "StableUnCLIPPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "txt2img"
        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "img2img"
        elif pipeline_name == "PaintByExamplePipeline":
            model_type = "PaintByExample"
        elif pipeline_name == "LDMTextToImagePipeline":
            model_type = "LDMTextToImage"
        else:
            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

        # remove huggingface url
        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
            if pretrained_model_link_or_path.startswith(prefix):
                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
        if not ckpt_path.is_file():
            # get repo_id and (potentially nested) file path of ckpt in repo
            repo_id = "/".join(ckpt_path.parts[:2])
            file_path = "/".join(ckpt_path.parts[2:])

            # Hub browser URLs look like "{repo_id}/blob/main/{file}"; drop those two segments.
            if file_path.startswith("blob/"):
                file_path = file_path[len("blob/") :]

            if file_path.startswith("main/"):
                file_path = file_path[len("main/") :]

            pretrained_model_link_or_path = hf_hub_download(
                repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                force_download=force_download,
            )

        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,
            pipeline_class=cls,
            model_type=model_type,
            stable_unclip=stable_unclip,
            controlnet=controlnet,
            from_safetensors=from_safetensors,
            extract_ema=extract_ema,
            image_size=image_size,
            scheduler_type=scheduler_type,
            num_in_channels=num_in_channels,
            upcast_attention=upcast_attention,
            load_safety_checker=load_safety_checker,
            prediction_type=prediction_type,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        if torch_dtype is not None:
            pipe.to(torch_dtype=torch_dtype)

        return pipe
def get_activation(act_fn):
    """Build a fresh activation module for the given name.

    Args:
        act_fn: Activation name. ``"swish"`` and ``"silu"`` both map to
            ``nn.SiLU``; ``"mish"`` maps to ``nn.Mish``; ``"gelu"`` maps to
            ``nn.GELU``.

    Returns:
        A newly constructed ``torch.nn`` activation module.

    Raises:
        ValueError: If ``act_fn`` is not one of the supported names.
    """
    # (accepted names, module class) pairs, checked in order.
    activation_table = (
        (("swish", "silu"), nn.SiLU),
        (("mish",), nn.Mish),
        (("gelu",), nn.GELU),
    )
    for accepted_names, module_cls in activation_table:
        if act_fn in accepted_names:
            return module_cls()

    raise ValueError(f"Unsupported activation function: {act_fn}")
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: a first attention layer, an optional second attention layer, and a feed-forward
    layer, each preceded by its own normalization and followed by a residual addition.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`. Required when `norm_type`
            is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to both `Attention` layers.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            `elementwise_affine` setting of the plain `nn.LayerNorm` layers created by this block.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            `"ada_norm"` selects `AdaLayerNorm`, `"ada_norm_zero"` selects `AdaLayerNormZero` (first norm only);
            any other value uses `nn.LayerNorm`.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Forwarded to `FeedForward`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        # The adaptive norms are only enabled when an embedding count is supplied.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        # attn1 only receives a cross-attention dim when the block is cross-attention-only.
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Store the chunk size and the dimension along which `forward` splits the feed-forward input.

        A `chunk_size` of `None` disables chunking.
        """
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        posemb: Optional[Any] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """Run the attention layer(s) and the feed-forward layer with residual connections.

        Args:
            hidden_states: Input to the block; also the residual stream that is returned.
            attention_mask: Passed to the first attention layer as `attention_mask`.
            encoder_hidden_states: Passed to the second attention layer, and to the first one only when
                `only_cross_attention` is set.
            encoder_attention_mask: Passed to the second attention layer as `attention_mask`.
            timestep: Consumed by the adaptive norm layers when they are enabled.
            posemb: Forwarded unchanged to both attention layers (its structure is defined by `Attention`,
                which is not visible here).
            cross_attention_kwargs: Extra keyword arguments expanded into both attention calls; `None` is
                treated as an empty dict.
            class_labels: Consumed by `AdaLayerNormZero` when `norm_type` is `"ada_norm_zero"`.

        Returns:
            The updated `hidden_states`.

        Raises:
            ValueError: If chunked feed-forward is enabled and the chunked dimension is not divisible by the
                chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # adaLN-Zero also returns the gate/shift/scale terms used further down in this method.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            posemb=posemb,  # TODO: in self-attention, should posemb be [pose_in, pose_in]?
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                posemb=posemb,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            # Run the feed-forward on each slice separately and stitch the results back together.
            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward. One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or
            `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # A single elif chain with an explicit failure branch: previously an unknown name fell through
        # every test and crashed later with an UnboundLocalError on `act_fn`.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply the layers of `self.net` in order and return the result."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with tanh approximation support via `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): Passed to `torch.nn.functional.gelu`.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate):
        """Apply GELU to `gate`, computing in float32 on `mps` devices and casting back."""
        if gate.device.type != "mps":
            return F.gelu(gate, approximate=self.approximate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        """Project `hidden_states` and apply the activation."""
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # One projection produces both the value half and the gate half.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        """Apply GELU to `gate`, computing in float32 on `mps` devices and casting back."""
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        """Project, split the last dimension in two, and gate the first half with GELU of the second."""
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class ApproximateGELU(nn.Module):
    """
    The approximate form of Gaussian Error Linear Unit (GELU)

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        """Project `x` and return `x * sigmoid(1.702 * x)`."""
        x = self.proj(x)
        return x * torch.sigmoid(1.702 * x)


class AdaLayerNorm(nn.Module):
    """
    Norm layer modified to incorporate timestep embeddings.

    The timestep is looked up in an embedding table and projected to a scale and a shift that modulate a
    non-affine `LayerNorm` of the input.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        """Return `norm(x) * (1 + scale) + shift` with scale/shift derived from `timestep`."""
        emb = self.linear(self.silu(self.emb(timestep)))
        # NOTE(review): `torch.chunk` splits along dim 0 here, which yields the two halves of the feature
        # vector only when `emb` is 1-D; presumably `timestep` is a 0-dim tensor -- confirm against callers.
        scale, shift = torch.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x


class AdaLayerNormZero(nn.Module):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Produces the modulated input together with the gate/shift/scale terms that the calling block applies
    around its attention and feed-forward layers.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None):
        """Return `(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`."""
        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
        # Six equal chunks along dim 1: shift/scale/gate for attention, then shift/scale/gate for the MLP.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaGroupNorm(nn.Module):
    """
    GroupNorm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): Size of the conditioning embedding fed to the linear projection.
        out_dim (`int`): Number of scale (and shift) values produced; the projection outputs `2 * out_dim`.
        num_groups (`int`): Number of groups passed to `F.group_norm`.
        act_fn (`str`, *optional*): Name of an activation applied to the embedding before projection.
        eps (`float`, *optional*, defaults to 1e-5): Epsilon passed to `F.group_norm`.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        if act_fn is None:
            self.act = None
        else:
            self.act = get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x, emb):
        """Group-normalize `x` and modulate it with a scale and shift projected from `emb`."""
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        # Add two trailing singleton dims so scale/shift broadcast over the trailing dims of `x`.
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries.

    Keys and values are consumed in chunks of `key_chunk_size` along the key axis. Every chunk yields
    un-normalised exponentiated values, their weight sums and the chunk-local maximum score; the chunks
    are then rescaled to a common maximum and combined, which equals a softmax over all keys.

    `key_chunk_size` does not have to divide the key length: a shorter trailing chunk is summarised on
    its own. (Slicing it with `dynamic_slice` would clamp the start index and count some keys twice.)

    Args:
        query: (batch..., query_length, head, query_key_depth_per_head)
        key: (batch..., key_value_length, head, query_key_depth_per_head)
        value: (batch..., key_value_length, head, value_depth_per_head)
        precision: numerical precision forwarded to `jnp.einsum`
        key_chunk_size: maximum number of keys processed at once

    Returns:
        Array of shape (batch..., query_length, head, value_depth_per_head).
    """
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        # Subtract the chunk-local maximum before exponentiating for numerical stability.
        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        # The trailing axis has length one here, so this einsum just drops it.
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    # Only full-sized chunks go through `lax.map`; shapes are static, so the split is decided in Python.
    num_full_chunks = num_kv // key_chunk_size
    tail_start = num_full_chunks * key_chunk_size
    chunk_values, chunk_weights, chunk_max = jax.lax.map(
        f=chunk_scanner, xs=jnp.arange(0, tail_start, key_chunk_size)
    )

    if tail_start < num_kv:
        # Ragged remainder: summarise it directly and append it as one more chunk.
        tail_values, tail_weights, tail_max = summarize_chunk(
            query, key[..., tail_start:, :, :], value[..., tail_start:, :, :]
        )
        chunk_values = jnp.concatenate([chunk_values, tail_values[None]], axis=0)
        chunk_weights = jnp.concatenate([chunk_weights, tail_weights[None]], axis=0)
        chunk_max = jnp.concatenate([chunk_max, tail_max[None]], axis=0)

    # Bring every chunk onto the same maximum before summing across chunks.
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights


def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide query array. Values larger than query_length are clamped to it. The value
            does not have to divide query_length; a shorter trailing chunk is processed separately.
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide key and value array. Values larger than key_value_length are clamped to it.
            The value does not have to divide key_value_length; a shorter trailing chunk is processed
            separately.

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]
    query_chunk_size = min(query_chunk_size, num_q)

    def attend(query_chunk):
        return _query_chunk_attention(
            query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
        )

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [query_chunk_size, num_heads, q_features],  # [...,q,h,d]
        )

        # The carry is the start offset of the next chunk.
        return chunk_idx + query_chunk_size, attend(query_chunk)

    # Scan over full-sized chunks only, so no slice start is ever clamped by `dynamic_slice`.
    num_full_chunks = num_q // query_chunk_size
    _, res = jax.lax.scan(f=chunk_scanner, init=0, xs=None, length=num_full_chunks)

    # `res` is stacked along a new leading axis; concatenating over it fuses the chunks back together.
    fused = jnp.concatenate(res, axis=-3)

    tail_start = num_full_chunks * query_chunk_size
    if tail_start < num_q:
        # Ragged remainder of the queries, attended in one extra step.
        fused = jnp.concatenate([fused, attend(query[..., tail_start:, :, :])], axis=-3)

    return fused
dim_head (:obj:`int`, *optional*, defaults to 64): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ query_dim: int heads: int = 8 dim_head: int = 64 dropout: float = 0.0 use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): inner_dim = self.dim_head * self.heads self.scale = self.dim_head**-0.5 # Weights were exported with old names {to_q, to_k, to_v, to_out} self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") self.dropout_layer = nn.Dropout(rate=self.dropout) def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def __call__(self, hidden_states, context=None, deterministic=True): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) value_proj = self.value(context) query_states = self.reshape_heads_to_batch_dim(query_proj) key_states = 
class FlaxBasicTransformerBlock(nn.Module):
    r"""
    A Flax transformer block: attention, cross attention and a gated feed-forward layer, each preceded by a
    LayerNorm and wrapped in a residual connection. Dropout is applied to the block output.

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        only_cross_attention (`bool`, defaults to `False`):
            Whether the first attention layer also attends to `context` instead of to its own input.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        # Both attention layers are configured identically.
        attention_kwargs = dict(
            query_dim=self.dim,
            heads=self.n_heads,
            dim_head=self.d_head,
            dropout=self.dropout,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = FlaxAttention(**attention_kwargs)
        # cross attention
        self.attn2 = FlaxAttention(**attention_kwargs)
        self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        # First attention: `None` makes the attention layer fall back to its own input as context.
        first_context = context if self.only_cross_attention else None
        attended = self.attn1(self.norm1(hidden_states), first_context, deterministic=deterministic)
        hidden_states = attended + hidden_states

        # cross attention
        attended = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = attended + hidden_states

        # feed forward
        transformed = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = transformed + hidden_states

        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxTransformer2DModel(nn.Module):
    r"""
    A Spatial Transformer layer: the channels-last feature map is group-normalised, projected, flattened to
    a token sequence, run through `depth` transformer blocks, reshaped back, projected again and added to
    the input. See https://arxiv.org/pdf/1506.02025.pdf

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        n_heads (:obj:`int`):
            Number of heads
        d_head (:obj:`int`):
            Hidden states dimension inside each head
        depth (:obj:`int`, *optional*, defaults to 1):
            Number of transformers block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        use_linear_projection (`bool`, defaults to `False`):
            Use `nn.Dense` on the flattened tokens for the input/output projections instead of a 1x1 `nn.Conv`
            on the feature map.
        only_cross_attention (`bool`, defaults to `False`):
            Forwarded to every `FlaxBasicTransformerBlock`.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
    """
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        # Shared settings of the 1x1 convolution used when linear projection is disabled.
        pointwise_conv = dict(kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype)

        if self.use_linear_projection:
            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_in = nn.Conv(inner_dim, **pointwise_conv)

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
            )
            for _ in range(self.depth)
        ]

        if self.use_linear_projection:
            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
        else:
            self.proj_out = nn.Conv(inner_dim, **pointwise_conv)

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        token_shape = (batch, height * width, channels)
        map_shape = (batch, height, width, channels)

        residual = hidden_states
        hidden_states = self.norm(hidden_states)

        # Dense projects the flattened tokens; Conv projects the feature map before flattening.
        if self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states.reshape(*token_shape))
        else:
            hidden_states = self.proj_in(hidden_states).reshape(*token_shape)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context, deterministic=deterministic)

        # Mirror of the input projection: restore the spatial layout around `proj_out`.
        if self.use_linear_projection:
            hidden_states = self.proj_out(hidden_states).reshape(*map_shape)
        else:
            hidden_states = self.proj_out(hidden_states.reshape(*map_shape))

        hidden_states = hidden_states + residual
        return self.dropout_layer(hidden_states, deterministic=deterministic)


class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
    https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        gated = self.net_0(hidden_states, deterministic=deterministic)
        return self.net_2(gated)


class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # One projection produces both the linear half and the gate half.
        hidden_dim = self.dim * 4
        self.proj = nn.Dense(hidden_dim * 2, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        projected = self.proj(hidden_states)
        # Split along axis 2 into two equal halves: first the linear part, then the GELU gate input.
        hidden_linear, hidden_gelu = jnp.split(projected, 2, axis=2)
        gated = hidden_linear * nn.gelu(hidden_gelu)
        return self.dropout_layer(gated, deterministic=deterministic)
def rotate_every_two(x):
    """
    Rotate each consecutive pair of features in the last dimension by 90 degrees.

    The last dimension is read as pairs `(x1, x2)` and every pair is replaced by `(-x2, x1)`, so
    `[a, b, c, d]` becomes `[-b, a, -d, c]`.

    Args:
        x: tensor whose last dimension has even length.

    Returns:
        Tensor of the same shape as `x`.
    """
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=-2)


def cape(x, p):
    """
    Expand a pose vector so that it lines up with the feature dimension of `x`.

    Each of the `n` pose components is repeated `d // n` times in a row, where `d` is the size of the last
    dimension of `x`. Only the shape of `x` is used.

    Args:
        x: feature tensor; its last dimension `d` must be a multiple of `2 * n`, so every feature pair
            rotated by `rotate_every_two` receives a single angle.
        p: pose tensor with `n` components in its last dimension.

    Returns:
        Tensor shaped like `p` except that the last dimension has length `d`.
    """
    feature_dim, pose_dim = x.shape[-1], p.shape[-1]
    assert feature_dim % (2 * pose_dim) == 0
    return p.repeat_interleave(feature_dim // pose_dim, dim=-1)


def cape_embed(p1, p2, qq, kk):
    """
    Embed camera position encoding into the query and key features.

    Every feature pair of `qq` / `kk` is rotated by the angle given by the matching (repeated) pose
    component, so the dot product of a rotated query and key depends on the difference of their poses.

    Args:
        p1: query pose  b, l_q, pose_dim
        p2: key pose    b, l_k, pose_dim
        qq: query feature map   b, l_q, feature_dim
        kk: key feature map     b, l_k, feature_dim

    Returns:
        Tuple `(q, k)` of the pose-rotated query and key features, shaped like `qq` and `kk`.
    """
    assert p1.shape[-1] == p2.shape[-1]
    assert qq.shape[-1] == kk.shape[-1]
    assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
    assert p1.shape[1] == qq.shape[1]
    assert p2.shape[1] == kk.shape[1]

    query_angles = cape(qq, p1)
    key_angles = cape(kk, p2)

    q = (qq * query_angles.cos()) + (rotate_every_two(qq) * query_angles.sin())
    k = (kk * key_angles.cos()) + (rotate_every_two(kk) * key_angles.sin())

    return q, k
cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, ): super().__init__() inner_dim = dim_head * heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection self.dropout = dropout # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` 
self.sliceable_head_dim = heads self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention if self.added_kv_proj_dim is None and self.only_cross_attention: raise ValueError( "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." ) if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None if spatial_norm_dim is not None: self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) else: self.spatial_norm = None if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(cross_attention_dim) elif cross_attention_norm == "group_norm": if self.added_kv_proj_dim is not None: # The given `encoder_hidden_states` are initially of shape # (batch_size, seq_len, added_kv_proj_dim) before being projected # to (batch_size, seq_len, cross_attention_dim). The norm is applied # before the projection, so we need to use `added_kv_proj_dim` as # the number of channels for the group norm. norm_cross_num_channels = added_kv_proj_dim else: norm_cross_num_channels = cross_attention_dim self.norm_cross = nn.GroupNorm( num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True ) else: raise ValueError( f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" ) self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) else: self.to_k = None self.to_v = None if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 if processor is None: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): is_lora = hasattr(self, "processor") and isinstance( self.processor, (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor), ) is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) ) is_added_kv_processor = hasattr(self, "processor") and isinstance( self.processor, ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, SlicedAttnAddedKVProcessor, XFormersAttnAddedKVProcessor, LoRAAttnAddedKVProcessor, ), ) if use_memory_efficient_attention_xformers: if is_added_kv_processor and (is_lora or is_custom_diffusion): raise NotImplementedError( f"Memory efficient attention is currently not supported 
for LoRA or custom diffuson for attention processor type {self.processor}" ) if not is_xformers_available(): raise ModuleNotFoundError( ( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers" ), name="xformers", ) elif not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" " only available for GPU " ) else: try: # Make sure we can run the memory efficient attention _ = xformers.ops.memory_efficient_attention( torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), torch.randn((1, 2, 40), device="cuda"), ) except Exception as e: raise e if is_lora: # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? processor = LoRAXFormersAttnProcessor( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, rank=self.processor.rank, attention_op=attention_op, ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) elif is_custom_diffusion: processor = CustomDiffusionXFormersAttnProcessor( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, attention_op=attention_op, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) elif is_added_kv_processor: # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP # which uses this type of cross attention ONLY because the attention mask of format # [0, ..., -10.000, ..., 0, ...,] is not supported # throw warning logger.info( "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is 
required for the attention operation." ) processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: if is_lora: attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) processor = attn_processor_class( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, rank=self.processor.rank, ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) elif is_custom_diffusion: processor = CustomDiffusionAttnProcessor( train_kv=self.processor.train_kv, train_q_out=self.processor.train_q_out, hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, ) processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. 
TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_attention_slice(self, slice_size): if slice_size is not None and slice_size > self.sliceable_head_dim: raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") if slice_size is not None and self.added_kv_proj_dim is not None: processor = SlicedAttnAddedKVProcessor(slice_size) elif slice_size is not None: processor = SlicedAttnProcessor(slice_size) elif self.added_kv_proj_dim is not None: processor = AttnAddedKVProcessor() else: # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) def set_processor(self, processor: "AttnProcessor"): # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( hasattr(self, "processor") and isinstance(self.processor, torch.nn.Module) and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") self.processor = processor def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty return self.processor( self, hidden_states, 
encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, **cross_attention_kwargs, ) def batch_to_head_dim(self, tensor): head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def head_to_batch_dim(self, tensor, out_dim=3): head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3) if out_dim == 3: tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def get_attention_scores(self, query, key, attention_mask=None): dtype = query.dtype if self.upcast_attention: query = query.float() key = key.float() if attention_mask is None: baddbmm_input = torch.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device ) beta = 0 else: baddbmm_input = attention_mask beta = 1 attention_scores = torch.baddbmm( baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale, ) del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) del attention_scores attention_probs = attention_probs.to(dtype) return attention_probs def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): if batch_size is None: deprecate( "batch_size=None", "0.0.15", ( "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" " `prepare_attention_mask` when preparing the attention_mask." 
), ) batch_size = 1 head_size = self.heads if attention_mask is None: return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: # we want to instead pad by (0, remaining_length), where remaining_length is: # remaining_length: int = target_length - current_length # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: attention_mask = attention_mask.repeat_interleave(head_size, dim=0) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) attention_mask = attention_mask.repeat_interleave(head_size, dim=1) return attention_mask def norm_encoder_hidden_states(self, encoder_hidden_states): assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" if isinstance(self.norm_cross, nn.LayerNorm): encoder_hidden_states = self.norm_cross(encoder_hidden_states) elif isinstance(self.norm_cross, nn.GroupNorm): # Group norm norms along the channels dimension and expects # input to be in the shape of (N, C, *). 
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.

    It delegates every projection, reshape and score computation to the `attn` module it is called
    with and holds no state of its own.
    """

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs are flattened to (batch, height * width, channel) and restored at the end.
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized from the sequence that will provide keys and values.
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = kv_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # `to_out` holds the output projection followed by dropout.
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor


class LoRALinearLayer(nn.Module):
    """
    Low-rank linear update: a bias-free `down` projection to `rank` features followed by a bias-free `up`
    projection to `out_features`.

    `up` is initialised to zeros, so a freshly constructed layer outputs zeros. When `network_alpha` is
    given, the output is scaled by `network_alpha / rank`.

    Raises:
        ValueError: if `rank` exceeds `min(in_features, out_features)`.
    """

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        max_rank = min(in_features, out_features)
        if rank > max_rank:
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        # Compute in the dtype of the LoRA weights, then hand the result back in the caller's dtype.
        input_dtype = hidden_states.dtype
        weight_dtype = self.down.weight.dtype

        update = self.up(self.down(hidden_states.to(weight_dtype)))

        if self.network_alpha is not None:
            update = update * (self.network_alpha / self.rank)

        return update.to(input_dtype)
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) 
hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class CustomDiffusionAttnProcessor(nn.Module): r""" Processor for implementing attention for the Custom Diffusion method. Args: train_kv (`bool`, defaults to `True`): Whether to newly train the key and value matrices corresponding to the text features. train_q_out (`bool`, defaults to `True`): Whether to newly train query matrices corresponding to the latent image features. hidden_size (`int`, *optional*, defaults to `None`): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*, defaults to `None`): The number of channels in the `encoder_hidden_states`. out_bias (`bool`, defaults to `True`): Whether to include the bias parameter in `train_q_out`. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. """ def __init__( self, train_kv=True, train_q_out=True, hidden_size=None, cross_attention_dim=None, out_bias=True, dropout=0.0, ): super().__init__() self.train_kv = train_kv self.train_q_out = train_q_out self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim # `_custom_diffusion` id for easy serialization and loading. 
if self.train_kv: self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) if self.train_q_out: self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) self.to_out_custom_diffusion = nn.ModuleList([]) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: query = self.to_q_custom_diffusion(hidden_states) else: query = attn.to_q(hidden_states) if encoder_hidden_states is None: crossattn = False encoder_hidden_states = hidden_states else: crossattn = True if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.train_kv: key = self.to_k_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states) else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) if crossattn: detach = torch.ones_like(key) detach[:, :1, :] = detach[:, :1, :] * 0.0 key = detach * key + (1 - detach) * key.detach() value = detach * value + (1 - detach) * value.detach() query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) if self.train_q_out: # linear proj hidden_states = self.to_out_custom_diffusion[0](hidden_states) # dropout hidden_states = self.to_out_custom_diffusion[1](hidden_states) else: # linear proj hidden_states = 
attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class AttnAddedKVProcessor: r""" Processor for performing attention-related computations with extra learnable key and value matrices for the text encoder. """ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) else: key = encoder_hidden_states_key_proj value = encoder_hidden_states_value_proj attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) hidden_states = 
hidden_states.transpose(-1, -2).reshape(residual.shape) hidden_states = hidden_states + residual return hidden_states class AttnAddedKVProcessor2_0: r""" Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra learnable key and value matrices for the text encoder. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) else: key = encoder_hidden_states_key_proj value = encoder_hidden_states_value_proj # the output of 
sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) hidden_states = hidden_states + residual return hidden_states class LoRAAttnAddedKVProcessor(nn.Module): r""" Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text encoder. Args: hidden_size (`int`, *optional*): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*, defaults to `None`): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. 
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora( encoder_hidden_states ) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora( encoder_hidden_states ) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: key = attn.to_k(hidden_states) + 
scale * self.to_k_lora(hidden_states) value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) else: key = encoder_hidden_states_key_proj value = encoder_hidden_states_value_proj attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) hidden_states = hidden_states + residual return hidden_states class XFormersAttnAddedKVProcessor: r""" Processor for implementing memory efficient attention using xFormers. Args: attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. 
""" def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) query = attn.head_to_batch_dim(query) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) else: key = encoder_hidden_states_key_proj value = encoder_hidden_states_value_proj hidden_states = xformers.ops.memory_efficient_attention( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) hidden_states = hidden_states + residual return hidden_states 
class XFormersAttnProcessor: r""" Processor for implementing memory efficient attention using xFormers. Args: attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. """ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, posemb: Optional = None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) if posemb is not None: # turn 2d attention into multiview attention self_attn = encoder_hidden_states is None # check if self attn or cross attn p_out, p_in = posemb t_out, t_in = p_out.shape[1], p_in.shape[1] # t size hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out) batch_size, key_tokens, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) if attention_mask is not None: # expand our mask's singleton query_tokens dimension: # [batch*heads, 1, key_tokens] -> # [batch*heads, query_tokens, key_tokens] # so that it can be added as a bias onto the attention scores that xformers computes: # [batch*heads, query_tokens, key_tokens] # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. 
_, query_tokens, _ = hidden_states.shape attention_mask = attention_mask.expand(-1, query_tokens, -1) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # apply 4DoF CaPE, todo now only for xformer processor if posemb is not None: p_out = einops.repeat(p_out, 'b t_out d -> b (t_out l) d', l=query.shape[1]//t_out) # query shape if self_attn: p_in = p_out else: p_in = einops.repeat(p_in, 'b t_in d -> b (t_in l) d', l=key.shape[1] // t_in) # key shape query, key = cape_embed(p_out, p_in, query, key) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() # self-ttn (bm) l c x (bm) l c -> (bm) l c # cross-ttn (bm) l c x b (nl) c -> (bm) l c # reuse 2d attention for multiview attention # self-ttn b (ml) c x b (ml) c -> b (ml) c # cross-ttn b (ml) c x b (nl) c -> b (ml) c hidden_states = xformers.ops.memory_efficient_attention( # query: (bm) l c -> b (ml) c; key: b (nl) c query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = hidden_states.to(query.dtype) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if posemb is not None: # reshape back hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor 
return hidden_states class AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) inner_dim = hidden_states.shape[-1] if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move 
to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class LoRAXFormersAttnProcessor(nn.Module): r""" Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers. Args: hidden_size (`int`, *optional*): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. network_alpha (`int`, *optional*): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" def __init__( self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None ): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.attention_op = attention_op self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) key = attn.head_to_batch_dim(key).contiguous() value = attn.head_to_batch_dim(value).contiguous() hidden_states = 
xformers.ops.memory_efficient_attention( query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale ) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class LoRAAttnProcessor2_0(nn.Module): r""" Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product attention. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*): The number of channels in the `encoder_hidden_states`. rank (`int`, defaults to 4): The dimension of the LoRA update matrices. network_alpha (`int`, *optional*): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) inner_dim = hidden_states.shape[-1] if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = 
attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class CustomDiffusionXFormersAttnProcessor(nn.Module): r""" Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. Args: train_kv (`bool`, defaults to `True`): Whether to newly train the key and value matrices corresponding to the text features. train_q_out (`bool`, defaults to `True`): Whether to newly train query matrices corresponding to the latent image features. hidden_size (`int`, *optional*, defaults to `None`): The hidden size of the attention layer. cross_attention_dim (`int`, *optional*, defaults to `None`): The number of channels in the `encoder_hidden_states`. out_bias (`bool`, defaults to `True`): Whether to include the bias parameter in `train_q_out`. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
attention_op (`Callable`, *optional*, defaults to `None`): The base [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. """ def __init__( self, train_kv=True, train_q_out=False, hidden_size=None, cross_attention_dim=None, out_bias=True, dropout=0.0, attention_op: Optional[Callable] = None, ): super().__init__() self.train_kv = train_kv self.train_q_out = train_q_out self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.attention_op = attention_op # `_custom_diffusion` id for easy serialization and loading. if self.train_kv: self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) if self.train_q_out: self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False) self.to_out_custom_diffusion = nn.ModuleList([]) self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias)) self.to_out_custom_diffusion.append(nn.Dropout(dropout)) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: query = self.to_q_custom_diffusion(hidden_states) else: query = attn.to_q(hidden_states) if encoder_hidden_states is None: crossattn = False encoder_hidden_states = hidden_states else: crossattn = True if attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.train_kv: key = self.to_k_custom_diffusion(encoder_hidden_states) value = self.to_v_custom_diffusion(encoder_hidden_states) else: key 
class SlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Attention is computed chunk by chunk along the leading dimension of the query returned by
    `attn.head_to_batch_dim`, so that only `slice_size` rows of attention scores are materialized at a time.

    Args:
        slice_size (`int`):
            Number of rows of the leading query dimension processed per step. It does not have to divide that
            dimension evenly: the last chunk is simply smaller.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Run attention over `hidden_states` in chunks of `self.slice_size`.

        Args:
            attn: The attention module providing the projections (`to_q`, `to_k`, `to_v`, `to_out`), the optional
                norms and the helper methods used below.
            hidden_states: 3D or 4D input tensor. A 4D input is flattened over its last two dimensions for the
                computation and reshaped back before returning.
            encoder_hidden_states: Optional source for keys and values; `hidden_states` is used when `None`.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask`.

        Returns:
            The attention output after the output projection, the optional residual connection and the division by
            `attn.rescale_output_factor`.
        """
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Step by `slice_size` rather than iterating `batch_size_attention // slice_size` times: the floor division
        # skipped the trailing partial chunk (and everything when `slice_size > batch_size_attention`), leaving
        # those output rows at zero. Slicing past the end is clamped, so the last chunk may be shorter.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
# Type alias: any one of the attention processor classes listed below.
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    XFormersAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
]
class AutoencoderKL(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to 1):
            Forwarded unchanged to both the `Encoder` and the `Decoder`.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Forwarded unchanged to both the `Encoder` and the `Decoder`.
        sample_size (`int`, *optional*, defaults to `32`):
            Sample input size. A list or tuple is also accepted, in which case its first element is used to derive
            the tile sizes for tiled encoding and decoding.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        # `sample_size` may be configured as a list/tuple; reduce it to a scalar once and use that scalar for BOTH
        # tile sizes. Storing the raw list/tuple in `tile_sample_min_size` made the `shape > tile_sample_min_size`
        # comparison in `encode` and the `int(size * factor)` arithmetic in the tiled methods fail.
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_sample_min_size = sample_size
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` to `value` on `module` when it is an `Encoder` or a `Decoder`."""
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        """
        Encode a batch of images into a latent distribution.

        Dispatches to `tiled_encode` when tiling is enabled and either of the last two dimensions of `x` exceeds
        `tile_sample_min_size`. Otherwise, when slicing is enabled and the batch has more than one element, the
        encoder runs on one batch element at a time.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            An [`AutoencoderKLOutput`] holding the posterior, or a one-element tuple with the posterior when
            `return_dict` is False.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode latents without batch slicing.

        Dispatches to `tiled_decode` when tiling is enabled and either of the last two dimensions of `z` exceeds
        `tile_latent_min_size`; otherwise applies `post_quant_conv` followed by the decoder.
        """
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of latents into images.

        When slicing is enabled and the batch has more than one element, each batch element is decoded separately
        and the results are concatenated.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.

        Returns:
            A [`DecoderOutput`] holding the decoded sample, or a one-element tuple with it when `return_dict` is
            False.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a, b, blend_extent):
        """
        Linearly cross-fade the last `blend_extent` rows (dim 2) of `a` into the first rows of `b`.

        `b` is modified in place and returned. `blend_extent` is clamped to the size of dim 2 of both tensors.
        """
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        """
        Linearly cross-fade the last `blend_extent` columns (dim 3) of `a` into the first columns of `b`.

        `b` is modified in place and returned. `blend_extent` is clamped to the size of dim 3 of both tensors.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile is encoded independently of the others. To avoid tiling
        artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized
        changes in the output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into overlapping tiles of side `tile_sample_min_size` and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of side `tile_latent_min_size` and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Encode `sample`, pick a latent from the posterior and decode it again.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior. When False, the posterior's mode is used instead.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Passed to the posterior's `sample` method when `sample_posterior` is True.

        Returns:
            A [`DecoderOutput`] holding the reconstruction, or a one-element tuple with it when `return_dict` is
            False.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor
freq_shift (`int`, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, defaults to 2): The number of layers per block. downsample_padding (`int`, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, defaults to 1): The scale factor to use for the mid block. act_fn (`str`, defaults to "silu"): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If None, normalization and activation layers is skipped in post-processing. norm_eps (`float`, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): The dimension of the attention heads. use_linear_projection (`bool`, defaults to `False`): class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. num_class_embeds (`int`, *optional*, defaults to 0): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. upcast_attention (`bool`, defaults to `False`): resnet_time_scale_shift (`str`, defaults to `"default"`): Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. 
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. global_pool_conditions (`bool`, defaults to `False`): """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, ): super().__init__() # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 
# The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = nn.Conv2d(output_channel, 
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet_conditioning_channel_order: str = "rgb",
    conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
    load_weights_from_unet: bool = True,
):
    r"""
    Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
            where applicable.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `"rgb"`):
            Forwarded unchanged to the [`ControlNetModel`] constructor.
        conditioning_embedding_out_channels (`Tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
            Forwarded unchanged to the [`ControlNetModel`] constructor.
        load_weights_from_unet (`bool`, *optional*, defaults to `True`):
            When `True`, the state dicts of `conv_in`, `time_proj`, `time_embedding`, `class_embedding` (if the
            new model has one), `down_blocks` and `mid_block` are copied from `unet`. When `False`, the new
            model keeps the weights it was constructed with.

    Returns:
        The newly constructed model (an instance of `cls`).
    """
    controlnet = cls(
        in_channels=unet.config.in_channels,
        flip_sin_to_cos=unet.config.flip_sin_to_cos,
        freq_shift=unet.config.freq_shift,
        down_block_types=unet.config.down_block_types,
        only_cross_attention=unet.config.only_cross_attention,
        block_out_channels=unet.config.block_out_channels,
        layers_per_block=unet.config.layers_per_block,
        downsample_padding=unet.config.downsample_padding,
        mid_block_scale_factor=unet.config.mid_block_scale_factor,
        act_fn=unet.config.act_fn,
        norm_num_groups=unet.config.norm_num_groups,
        norm_eps=unet.config.norm_eps,
        cross_attention_dim=unet.config.cross_attention_dim,
        attention_head_dim=unet.config.attention_head_dim,
        num_attention_heads=unet.config.num_attention_heads,
        use_linear_projection=unet.config.use_linear_projection,
        class_embed_type=unet.config.class_embed_type,
        num_class_embeds=unet.config.num_class_embeds,
        upcast_attention=unet.config.upcast_attention,
        resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
        projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
        controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
        conditioning_embedding_out_channels=conditioning_embedding_out_channels,
    )

    if load_weights_from_unet:
        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
        controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

        # Test for presence, not truthiness: a module's truth value can depend on
        # `__len__`/`__bool__`, and `forward` uses the same `is not None` check.
        if controlnet.class_embedding is not None:
            controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

    return controlnet
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: If `processor` is a dict whose length differs from the number of attention layers.
        KeyError: If `processor` is a dict that lacks the `"<module path>.processor"` key of some layer.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    # The recursive walk consumes entries with `pop`; do that on a shallow copy so
    # the mapping the caller handed in is left intact and can be reused.
    pending = dict(processor) if isinstance(processor, dict) else processor

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, pending)
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
def _set_gradient_checkpointing(self, module, value=False):
    """
    Set `module.gradient_checkpointing` to `value` for down blocks only.

    Modules that are neither `CrossAttnDownBlock2D` nor `DownBlock2D` are left untouched.
    """
    checkpointable_types = (CrossAttnDownBlock2D, DownBlock2D)
    if not isinstance(module, checkpointable_types):
        return
    module.gradient_checkpointing = value
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`): cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. guess_mode (`bool`, defaults to `False`): In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: [`~models.controlnet.ControlNetOutput`] **or** `tuple`: If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ # check channel order channel_order = self.config.controlnet_conditioning_channel_order if channel_order == "rgb": # in rgb order by default ... elif channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) else: raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # 2. pre-process sample = self.conv_in(sample) controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. 
def zero_module(module):
    """Fill every parameter of `module` with zeros in place and return the same module."""
    for parameter in module.parameters():
        nn.init.zeros_(parameter)

    return module
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional, Tuple, Union import flax import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .modeling_flax_utils import FlaxModelMixin from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, ) @flax.struct.dataclass class FlaxControlNetOutput(BaseOutput): """ The output of [`FlaxControlNetModel`]. 
class FlaxControlNetConditioningEmbedding(nn.Module):
    """
    Convolutional encoder applied to the ControlNet conditioning input.

    The stack is: an input 3x3 conv, then for each consecutive pair of `block_out_channels` a 3x3 conv
    followed by a stride-2 3x3 conv, then a zero-initialised 3x3 output conv. A SiLU follows every conv
    except the output one.
    """

    # Number of output channels of the final (zero-initialised) convolution.
    conditioning_embedding_channels: int
    # Channel widths of the intermediate stages; each step between entries adds one stride-2 conv.
    block_out_channels: Tuple[int] = (16, 32, 96, 256)
    # dtype handed to every convolution.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        widths = self.block_out_channels
        same_padding = ((1, 1), (1, 1))

        self.conv_in = nn.Conv(
            widths[0],
            kernel_size=(3, 3),
            padding=same_padding,
            dtype=self.dtype,
        )

        stages = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            keep_resolution = nn.Conv(
                width_in,
                kernel_size=(3, 3),
                padding=same_padding,
                dtype=self.dtype,
            )
            halve_resolution = nn.Conv(
                width_out,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding=same_padding,
                dtype=self.dtype,
            )
            stages.extend([keep_resolution, halve_resolution])
        self.blocks = stages

        # Zero-initialised kernel and bias: the embedding output is all zeros until trained.
        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=same_padding,
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning):
        hidden = nn.silu(self.conv_in(conditioning))
        for stage in self.blocks:
            hidden = nn.silu(stage(hidden))
        return self.conv_out(hidden)
Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: sample_size (`int`, *optional*): The size of the input sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): The tuple of downsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int` or `Tuple[int]`, *optional*): The number of attention heads. cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. 
def init_weights(self, rng: jax.Array) -> FrozenDict:
    """
    Initialise the model parameters by running `self.init` on dummy inputs.

    Args:
        rng: PRNG key; it is split into one key for `"params"` and one for `"dropout"`.

    Returns:
        The `"params"` entry of the variables returned by `self.init`.
    """
    # NOTE: the annotation is `jax.Array` because `jax.random.KeyArray` has been removed from JAX
    # and the annotation is evaluated when the class body runs.

    # init input tensors
    sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
    sample = jnp.zeros(sample_shape, dtype=jnp.float32)
    timesteps = jnp.ones((1,), dtype=jnp.int32)
    encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
    # The conditioning input is 8x the sample resolution, matching the three stride-2 convolutions that
    # `FlaxControlNetConditioningEmbedding` applies with its default `block_out_channels`.
    # The 3 channels are presumably an RGB conditioning image (see `controlnet_conditioning_channel_order`).
    controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
    controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = self.num_attention_heads or self.attention_head_dim # input self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # time self.time_proj = FlaxTimesteps( block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift ) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=self.conditioning_embedding_out_channels, ) only_cross_attention = self.only_cross_attention if isinstance(only_cross_attention, bool): only_cross_attention = (only_cross_attention,) * len(self.down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(self.down_block_types) # down down_blocks = [] controlnet_down_blocks = [] output_channel = block_out_channels[0] controlnet_block = nn.Conv( output_channel, kernel_size=(1, 1), padding="VALID", kernel_init=nn.initializers.zeros_init(), bias_init=nn.initializers.zeros_init(), dtype=self.dtype, ) controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(self.down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 if down_block_type == "CrossAttnDownBlock2D": down_block = FlaxCrossAttnDownBlock2D( in_channels=input_channel, out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, num_attention_heads=num_attention_heads[i], add_downsample=not is_final_block, 
def __call__(
    self,
    sample,
    timesteps,
    encoder_hidden_states,
    controlnet_cond,
    conditioning_scale: float = 1.0,
    return_dict: bool = True,
    train: bool = False,
) -> Union[FlaxControlNetOutput, Tuple]:
    r"""
    Run the ControlNet and return its scaled residuals.

    Args:
        sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
        timesteps (`jnp.ndarray` or `float` or `int`): timesteps
        encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
        controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
        conditioning_scale (`float`): the scale factor applied to every controlnet output
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`FlaxControlNetOutput`] instead of a plain tuple.
        train (`bool`, *optional*, defaults to `False`):
            Use deterministic functions and disable dropout when not training.

    Returns:
        [`FlaxControlNetOutput`] or `tuple`:
        [`FlaxControlNetOutput`] if `return_dict` is True, otherwise the tuple
        `(down_block_res_samples, mid_block_res_sample)`, where `down_block_res_samples` is a list.
        `sample` is transposed to channels-last before the first convolution and is not transposed back.
    """
    if self.controlnet_conditioning_channel_order == "bgr":
        # reverse the channel axis so the conditioning is in rgb order
        controlnet_cond = jnp.flip(controlnet_cond, axis=1)

    # 1. time
    if not isinstance(timesteps, jnp.ndarray):
        timesteps = jnp.array([timesteps], dtype=jnp.int32)
    elif len(timesteps.shape) == 0:
        # 0-d array: cast to float32 and give it a leading axis
        timesteps = jnp.expand_dims(timesteps.astype(dtype=jnp.float32), 0)

    t_emb = self.time_embedding(self.time_proj(timesteps))

    # 2. pre-process (inputs are moved from channels-first to channels-last)
    deterministic = not train
    hidden = self.conv_in(jnp.transpose(sample, (0, 2, 3, 1)))
    cond_embedding = self.controlnet_cond_embedding(jnp.transpose(controlnet_cond, (0, 2, 3, 1)))
    hidden = hidden + cond_embedding

    # 3. down
    residuals = (hidden,)
    for down_block in self.down_blocks:
        if isinstance(down_block, FlaxCrossAttnDownBlock2D):
            hidden, block_residuals = down_block(
                hidden, t_emb, encoder_hidden_states, deterministic=deterministic
            )
        else:
            hidden, block_residuals = down_block(hidden, t_emb, deterministic=deterministic)
        residuals = residuals + block_residuals

    # 4. mid
    hidden = self.mid_block(hidden, t_emb, encoder_hidden_states, deterministic=deterministic)

    # 5. controlnet blocks
    projected = tuple(
        controlnet_block(residual)
        for residual, controlnet_block in zip(residuals, self.controlnet_down_blocks)
    )
    mid_block_res_sample = self.controlnet_mid_block(hidden)

    # 6. scaling
    down_block_res_samples = [residual * conditioning_scale for residual in projected]
    mid_block_res_sample = mid_block_res_sample * conditioning_scale

    if return_dict:
        return FlaxControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )

    return (down_block_res_samples, mid_block_res_sample)
Please import from diffusers.models.attention_processor instead.", standard_warn=False, ) AttnProcessor = AttentionProcessor class CrossAttention(Attention): def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) super().__init__(*args, **kwargs) class CrossAttnProcessor(AttnProcessorRename): def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) super().__init__(*args, **kwargs) class LoRACrossAttnProcessor(LoRAAttnProcessor): def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) super().__init__(*args, **kwargs) class CrossAttnAddedKVProcessor(AttnAddedKVProcessor): def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. Please use `from diffusers.models.attention_processor import {''.join(self.__class__.__name__.split('Cross'))} instead." deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False) super().__init__(*args, **kwargs) class XFormersCrossAttnProcessor(XFormersAttnProcessor): def __init__(self, *args, **kwargs): deprecation_message = f"{self.__class__.__name__} is deprecated and will be removed in `0.20.0`. 
class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
    """
    Deprecated subclass of `SlicedAttnAddedKVProcessor` kept for backward compatibility.

    Constructing it reports a deprecation through `deprecate` and then forwards every argument unchanged to the
    parent constructor.
    """

    def __init__(self, *args, **kwargs):
        # The suggested import name is the runtime class name with every "Cross" removed.
        legacy_name = self.__class__.__name__
        replacement_name = "".join(legacy_name.split("Cross"))
        deprecation_message = (
            f"{legacy_name} is deprecated and will be removed in `0.20.0`."
            f" Please use `from diffusers.models.attention_processor import {replacement_name} instead."
        )
        deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)
        super().__init__(*args, **kwargs)
class DualTransformer2DModel(nn.Module):
    """
    Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.

    Both transformers are constructed with exactly the same arguments. At call time each one receives its own
    slice of `encoder_hidden_states`, and the two resulting residuals are blended with `mix_ratio`.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Forwarded unchanged to each `Transformer2DModel`.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()

        # One shared argument set: the two wrapped transformers are configured identically.
        shared_kwargs = {
            "num_attention_heads": num_attention_heads,
            "attention_head_dim": attention_head_dim,
            "in_channels": in_channels,
            "num_layers": num_layers,
            "dropout": dropout,
            "norm_num_groups": norm_num_groups,
            "cross_attention_dim": cross_attention_dim,
            "attention_bias": attention_bias,
            "sample_size": sample_size,
            "num_vector_embeds": num_vector_embeds,
            "activation_fn": activation_fn,
            "num_embeds_ada_norm": num_embeds_ada_norm,
        }
        self.transformers = nn.ModuleList([Transformer2DModel(**shared_kwargs) for _ in range(2)])

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Run both transformers on `hidden_states`, each with its own condition tokens, and blend the residuals.

        Args:
            hidden_states: Input passed unchanged to both transformers and added back to the blended residual.
            encoder_hidden_states: Conditioning tensor. It is sliced along dimension 1 into two consecutive chunks
                of `condition_lengths[0]` and `condition_lengths[1]` tokens.
            timestep (*optional*): Forwarded to each transformer as `timestep`.
            attention_mask (*optional*): Accepted for interface compatibility; not used yet.
            cross_attention_kwargs (*optional*): Forwarded to each transformer as `cross_attention_kwargs`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is True, otherwise a one-element `tuple` holding the
            output tensor.
        """
        residuals = []
        offset = 0
        # attention_mask is not used yet
        for condition_idx in range(2):
            # Chunk `condition_idx` of the condition tokens goes to the transformer selected for it.
            condition_tokens = encoder_hidden_states[:, offset : offset + self.condition_lengths[condition_idx]]
            branch = self.transformers[self.transformer_index_for_condition[condition_idx]]
            branch_output = branch(
                hidden_states,
                encoder_hidden_states=condition_tokens,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            # Keep only what the branch added on top of the shared input.
            residuals.append(branch_output - hidden_states)
            offset += self.condition_lengths[condition_idx]

        blended = residuals[0] * self.mix_ratio + residuals[1] * (1 - self.mix_ratio)
        output_states = blended + hidden_states

        if return_dict:
            return Transformer2DModelOutput(sample=output_states)
        return (output_states,)
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. An odd value yields one trailing zero column.
    :param flip_sin_to_cos: if True the output is laid out as [cos, sin] instead of [sin, cos].
    :param downscale_freq_shift: subtracted from `embedding_dim // 2` in the denominator of the frequency exponent.
    :param scale: multiplier applied to the angles before taking sine and cosine.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    positions = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    log_frequencies = -math.log(max_period) * positions
    log_frequencies = log_frequencies / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(log_frequencies)

    # Outer product of timesteps and frequencies, then the optional global scale.
    angles = timesteps[:, None].float() * frequencies[None, :]
    angles = scale * angles

    sin_half = torch.sin(angles)
    cos_half = torch.cos(angles)
    halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    emb = torch.cat(halves, dim=-1)

    # An odd embedding_dim leaves one column unfilled; pad it with zeros on the right.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Build a 1-D sine/cosine positional embedding.

    embed_dim: output dimension for each position; must be even, otherwise `ValueError` is raised.
    pos: array of positions to be encoded; it is flattened to size (M,).
    out: (M, D) array whose first D/2 columns are sines and last D/2 columns are cosines.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000**(i / (D/2)) for i in [0, D/2).
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, None] * omega[None, :]  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Holds a frozen random frequency vector of length `embedding_size`, drawn from a standard normal and multiplied
    by `scale`. The forward pass returns the sine and cosine of `2 * pi * weight * x`, with `x` optionally replaced
    by `log(x)` first.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if not set_W_to_weight:
            return

        # to delete later
        # A second, independently drawn parameter is registered as `W`, and `weight` is then rebound to the same
        # object, so both names appear in the state dict and refer to one tensor.
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.weight = self.W

    def forward(self, x):
        # x is indexed as x[:, None], so it is expected to have a leading batch dimension.
        values = torch.log(x) if self.log else x
        angles = values[:, None] * self.weight[None, :] * 2 * np.pi

        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        ordered = (cos_part, sin_part) if self.flip_sin_to_cos else (sin_part, cos_part)
        return torch.cat(ordered, dim=-1)
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0` the embedding table gets one extra row, at index `num_classes`, which stands for a
    dropped label.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # bool -> 0/1: reserve an extra "dropped label" row only when dropout can happen.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels: Tensor of class indices; its first dimension is the batch.
            force_drop_ids (*optional*): Per-sample flags (tensor, list or array). Entries equal to 1 force the
                corresponding label to be dropped. When `None`, labels are dropped at random with probability
                `dropout_prob`.

        Returns:
            A tensor like `labels` where dropped entries are replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first and compare afterwards. Comparing a plain Python list with `== 1` yields a single
            # `False`, which would silently drop nothing. Also keep the mask on the labels' device.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        """
        Embed `labels`, dropping some of them first when training with dropout or when `force_drop_ids` is given.

        NOTE(review): with `dropout_prob == 0` the table has no extra row, so forcing a drop indexes row
        `num_classes`, which looks out of range; confirm callers never combine the two.
        """
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
class TextImageTimeEmbedding(nn.Module):
    """
    Projects text and image embeddings to `time_embed_dim` and sums them.

    The text branch is a linear projection followed by layer norm; the image branch is a linear projection only.
    """

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        # text branch
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        # image branch
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
        # text: project, then normalize
        projected_text = self.text_norm(self.text_proj(text_embeds))

        # image: project only
        projected_image = self.image_proj(image_embeds)

        return projected_image + projected_text
class AttentionPooling(nn.Module):
    """
    Pools a `(bs, length, width)` sequence into a single `(bs, width)` vector with multi-head attention.

    A class token (the sequence mean plus a learned positional embedding) is prepended to the sequence and used
    as the only query; keys and values come from the extended sequence.
    """

    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def _split_heads(self, tensor, batch_size):
        """Reshape `(bs, length, width)` into `(bs*n_heads, dim_per_head, length)`."""
        # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
        tensor = tensor.view(batch_size, -1, self.num_heads, self.dim_per_head)
        # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
        tensor = tensor.transpose(1, 2)
        # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
        tensor = tensor.reshape(batch_size * self.num_heads, -1, self.dim_per_head)
        # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
        return tensor.transpose(1, 2)

    def forward(self, x):
        # Unpacking three sizes also enforces that the input is 3-D.
        batch_size, _length, _width = x.size()

        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        sequence = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, dim_per_head, class_token_length)
        query = self._split_heads(self.q_proj(class_token), batch_size)
        # (bs*n_heads, dim_per_head, length+class_token_length)
        key = self._split_heads(self.k_proj(sequence), batch_size)
        value = self._split_heads(self.v_proj(sequence), batch_size)

        # (bs*n_heads, class_token_length, length+class_token_length):
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        # More stable with f16 than dividing afterwards
        scores = torch.einsum("bct,bcs->bts", query * scale, key * scale)
        attention = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        # (bs*n_heads, dim_per_head, class_token_length)
        pooled = torch.einsum("bts,bcs->bct", attention, value)

        # (bs*n_heads, dim_per_head, 1) --> (bs, 1, width)
        pooled = pooled.reshape(batch_size, -1, 1).transpose(1, 2)

        return pooled[:, 0, :]  # cls_token
def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Returns the positional encoding (same as Tensor2Tensor).

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
        embedding_dim: The number of output channels; must be even.
        freq_shift: Subtracted from `embedding_dim // 2` in the denominator of the log-timescale increment.
        min_timescale: The smallest time unit (should probably be 0.0).
        max_timescale: The largest time unit.
        flip_sin_to_cos: If True the output is laid out as [cos, sin] instead of [sin, cos].
        scale: Multiplier applied to the angles before taking sine and cosine.
    Returns:
        a Tensor of timing signals [N, num_channels]
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"

    num_timescales = float(embedding_dim // 2)
    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
    timescale_indices = jnp.arange(num_timescales, dtype=jnp.float32)
    inv_timescales = min_timescale * jnp.exp(timescale_indices * -log_timescale_increment)

    # Outer product (N, 1) * (1, D/2), then the optional global scale.
    angles = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
    scaled_time = scale * angles

    sin_part = jnp.sin(scaled_time)
    cos_part = jnp.cos(scaled_time)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    signal = jnp.concatenate(halves, axis=1)
    return jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
class FlaxTimesteps(nn.Module):
    r"""
    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
                Time step embedding dimension
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
                Forwarded to `get_sinusoidal_embeddings` as `flip_sin_to_cos`
        freq_shift (`float`, *optional*, defaults to `1`):
                Forwarded to `get_sinusoidal_embeddings` as `freq_shift`
    """
    dim: int = 32
    flip_sin_to_cos: bool = False
    freq_shift: float = 1

    @nn.compact
    def __call__(self, timesteps):
        # The module has no parameters of its own; it only binds its fields to the embedding function.
        embeddings = get_sinusoidal_embeddings(
            timesteps,
            embedding_dim=self.dim,
            flip_sin_to_cos=self.flip_sin_to_cos,
            freq_shift=self.freq_shift,
        )
        return embeddings
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary

    The rules are tried in order and the first match wins:

    1. norm `bias` with no `bias` entry but a `scale` entry in the Flax dict -> `scale`
    2. `weight`/`gamma` with a `scale` entry in the Flax dict -> `scale`
    3. `weight` with an `embedding` entry in the Flax dict -> `embedding`
    4. 4-D `weight` (conv) -> `kernel`, axes permuted with `transpose(2, 3, 1, 0)`
    5. any other `weight` (linear) -> `kernel`, transposed with `.T`
    6. `gamma` -> `weight`, `beta` -> `bias` (old PyTorch layer norm names)
    7. anything else is returned unchanged

    Args:
        pt_tuple_key: Tuple of name components of the PyTorch parameter; the last component is the leaf name.
        pt_tensor: The parameter value; must expose `ndim`, `transpose` and `T` for the weight rules.
        random_flax_state_dict: Mapping keyed by Flax tuple keys, used only for membership tests.

    Returns:
        A `(flax_tuple_key, tensor)` pair.
    """
    prefix, leaf = pt_tuple_key[:-1], pt_tuple_key[-1]
    scale_key = prefix + ("scale",)

    # conv norm or layer norm
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and leaf == "bias"
        and prefix + ("bias",) not in random_flax_state_dict
        and scale_key in random_flax_state_dict
    ):
        return scale_key, pt_tensor
    if leaf in ("weight", "gamma") and scale_key in random_flax_state_dict:
        return scale_key, pt_tensor

    # embedding
    # The original built the "embedding" key but returned the stale "scale" key, so embedding weights were
    # emitted under a name the Flax model does not have.
    embedding_key = prefix + ("embedding",)
    if leaf == "weight" and embedding_key in random_flax_state_dict:
        return embedding_key, pt_tensor

    # conv layer
    if leaf == "weight" and pt_tensor.ndim == 4:
        return prefix + ("kernel",), pt_tensor.transpose(2, 3, 1, 0)

    # linear layer
    if leaf == "weight":
        return prefix + ("kernel",), pt_tensor.T

    # old PyTorch layer norm weight
    if leaf == "gamma":
        return prefix + ("weight",), pt_tensor

    # old PyTorch layer norm bias
    if leaf == "beta":
        return prefix + ("bias",), pt_tensor

    return pt_tuple_key, pt_tensor
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from pickle import UnpicklingError from typing import Any, Dict, Union import jax import jax.numpy as jnp import msgpack.exceptions from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError from .. import __version__, is_torch_available from ..utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging, ) from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax logger = logging.get_logger(__name__) class FlaxModelMixin: r""" Base class for all Flax models. [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and saving models. - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _flax_internal_args = ["name", "parent", "dtype"] @classmethod def _from_config(cls, config, **kwargs): """ All context managers that the model should be initialized under go here. 
""" return cls(config, **kwargs) def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: """ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. """ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 def conditional_cast(param): if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): param = param.astype(dtype) return param if mask is None: return jax.tree_map(conditional_cast, params) flat_params = flatten_dict(params) flat_mask, _ = jax.tree_flatten(mask) for masked, key in zip(flat_mask, flat_params.keys()): if masked: param = flat_params[key] flat_params[key] = conditional_cast(param) return unflatten_dict(flat_params) def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast the `params` in place. This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` for params you want to cast, and `False` for those you want to skip. 
Examples: ```python >>> from diffusers import FlaxUNet2DConditionModel >>> # load model >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision >>> params = model.to_bf16(params) >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... } >>> mask = traverse_util.unflatten_dict(mask) >>> params = model.to_bf16(params, mask) ```""" return self._cast_floating_to(params, jnp.bfloat16, mask) def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` for params you want to cast, and `False` for those you want to skip. 
Examples: ```python >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model params will be in fp32, to illustrate the use of this method, >>> # we'll first cast to fp16 and back to fp32 >>> params = model.to_f16(params) >>> # now cast back to fp32 >>> params = model.to_fp32(params) ```""" return self._cast_floating_to(params, jnp.float32, mask) def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the `params` in place. This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full half-precision training or to save weights in float16 for inference in order to save memory and improve speed. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` for params you want to cast, and `False` for those you want to skip. Examples: ```python >>> from diffusers import FlaxUNet2DConditionModel >>> # load model >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model params will be in fp32, to cast these to float16 >>> params = model.to_fp16(params) >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... 
} >>> mask = traverse_util.unflatten_dict(mask) >>> params = model.to_fp16(params, mask) ```""" return self._cast_floating_to(params, jnp.float16, mask) def init_weights(self, rng: jax.random.KeyArray) -> Dict: raise NotImplementedError(f"init_weights method has to be implemented for {self}") @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], dtype: jnp.dtype = jnp.float32, *model_args, **kwargs, ): r""" Instantiate a pretrained Flax model from a pretrained model configuration. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved using [`~FlaxModelMixin.save_pretrained`]. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`. This only specifies the dtype of the *computation* and does not influence the dtype of model parameters. If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and [`~FlaxModelMixin.to_bf16`]. model_args (sequence of positional arguments, *optional*): All remaining positional arguments are passed to the underlying model's `__init__` method. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. 
resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_pt (`bool`, *optional*, defaults to `False`): Load the model weights from a PyTorch checkpoint save file. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it is loaded) and initiate the model (for example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded: - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done). - If a configuration is not provided, `kwargs` are first passed to the configuration class initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds to a configuration attribute is used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute are passed to the underlying model's `__init__` function. Examples: ```python >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co and cache. 
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` """ config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) from_pt = kwargs.pop("from_pt", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) user_agent = { "diffusers": __version__, "file_type": "model", "framework": "flax", } # Load config if we don't provide a configuration config_path = config if config is not None else pretrained_model_name_or_path model, model_kwargs = cls.from_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, # model args dtype=dtype, **kwargs, ) # Load model pretrained_path_with_subfolder = ( pretrained_model_name_or_path if subfolder is None else os.path.join(pretrained_model_name_or_path, subfolder) ) if 
os.path.isdir(pretrained_path_with_subfolder): if from_pt: if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): raise EnvironmentError( f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} " ) model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME) # Check if pytorch weights exist instead elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" " using `from_pt=True`." ) else: raise EnvironmentError( f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"{pretrained_path_with_subfolder}." ) else: try: model_file = hf_hub_download( pretrained_model_name_or_path, filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, user_agent=user_agent, subfolder=subfolder, revision=revision, ) except RepositoryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " "login`." ) except RevisionNotFoundError: raise EnvironmentError( f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " "this model name. Check the model page at " f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." 
) except EntryNotFoundError: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." ) except HTTPError as err: raise EnvironmentError( f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" f"{err}" ) except ValueError: raise EnvironmentError( f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" " internet connection or see how to run the library in offline mode at" " 'https://huggingface.co/docs/transformers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." ) if from_pt: if is_torch_available(): from .modeling_utils import load_state_dict else: raise EnvironmentError( "Can't load the model in PyTorch format because PyTorch is not installed. " "Please, install PyTorch or use native Flax weights." ) # Step 1: Get the pytorch file pytorch_model_file = load_state_dict(model_file) # Step 2: Convert the weights state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) else: try: with open(model_file, "rb") as state_f: state = from_bytes(cls, state_f.read()) except (UnpicklingError, msgpack.exceptions.ExtraData) as e: try: with open(model_file) as f: if f.read().startswith("version"): raise OSError( "You seem to have cloned a repository without having git-lfs installed. 
Please" " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" " folder you cloned." ) else: raise ValueError from e except (UnicodeDecodeError, ValueError): raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") # make sure all arrays are stored as jnp.ndarray # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: # https://github.com/google/flax/issues/1261 state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # flatten dicts state = flatten_dict(state) params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) shape_state = flatten_dict(unfreeze(params_shape_tree)) missing_keys = required_params - set(state.keys()) unexpected_keys = set(state.keys()) - required_params if missing_keys: logger.warning( f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " "Make sure to call model.init_weights to initialize the missing weights." ) cls._missing_keys = missing_keys for key in state.keys(): if key in shape_state and state[key].shape != shape_state[key].shape: raise ValueError( f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " ) # remove unexpected keys to not be saved again for unexpected_key in unexpected_keys: del state[unexpected_key] if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" " with another architecture." 
) else: logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") if len(missing_keys) > 0: logger.warning( f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" " TRAIN this model on a down-stream task to be able to use it for predictions and inference." ) else: logger.info( f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" f" was trained on, you can already use {model.__class__.__name__} for predictions without further" " training." ) return model, unflatten_dict(state) def save_pretrained( self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict], is_main_process: bool = True, ): """ Save a model and its configuration file to a directory so that it can be reloaded using the [`~FlaxModelMixin.from_pretrained`] class method. Arguments: save_directory (`str` or `os.PathLike`): Directory to save a model and its configuration file to. Will be created if it doesn't exist. params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. 
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return os.makedirs(save_directory, exist_ok=True) model_to_save = self # Attach architecture to the config # Save the config if is_main_process: model_to_save.save_config(save_directory) # save model output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) with open(output_model_file, "wb") as f: model_bytes = to_bytes(params) f.write(model_bytes) logger.info(f"Model weights saved in {output_model_file}") ================================================ FILE: 4DoF/diffusers/models/modeling_pytorch_flax_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch - Flax general utilities.""" from pickle import UnpicklingError import jax import jax.numpy as jnp import numpy as np from flax.serialization import from_bytes from flax.traverse_util import flatten_dict from ..utils import logging logger = logging.get_logger(__name__) ##################### # Flax => PyTorch # ##################### # from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352 def load_flax_checkpoint_in_pytorch_model(pt_model, model_file): try: with open(model_file, "rb") as flax_state_f: flax_state = from_bytes(None, flax_state_f.read()) except UnpicklingError as e: try: with open(model_file) as f: if f.read().startswith("version"): raise OSError( "You seem to have cloned a repository without having git-lfs installed. Please" " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" " folder you cloned." ) else: raise ValueError from e except (UnicodeDecodeError, ValueError): raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") return load_flax_weights_in_pytorch_model(pt_model, flax_state) def load_flax_weights_in_pytorch_model(pt_model, flax_state): """Load flax checkpoints in a PyTorch model""" try: import torch # noqa: F401 except ImportError: logger.error( "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see" " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation" " instructions." ) raise # check if we have bf16 weights is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() if any(is_type_bf16): # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16 # and bf16 is not fully supported in PT yet. logger.warning( "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " "before loading those in PyTorch model." 
) flax_state = jax.tree_util.tree_map( lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state ) pt_model.base_model_prefix = "" flax_state_dict = flatten_dict(flax_state, sep=".") pt_model_dict = pt_model.state_dict() # keep track of unexpected & missing keys unexpected_keys = [] missing_keys = set(pt_model_dict.keys()) for flax_key_tuple, flax_tensor in flax_state_dict.items(): flax_key_tuple_array = flax_key_tuple.split(".") if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4: flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1)) elif flax_key_tuple_array[-1] == "kernel": flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] flax_tensor = flax_tensor.T elif flax_key_tuple_array[-1] == "scale": flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"] if "time_embedding" not in flax_key_tuple_array: for i, flax_key_tuple_string in enumerate(flax_key_tuple_array): flax_key_tuple_array[i] = ( flax_key_tuple_string.replace("_0", ".0") .replace("_1", ".1") .replace("_2", ".2") .replace("_3", ".3") .replace("_4", ".4") .replace("_5", ".5") .replace("_6", ".6") .replace("_7", ".7") .replace("_8", ".8") .replace("_9", ".9") ) flax_key = ".".join(flax_key_tuple_array) if flax_key in pt_model_dict: if flax_tensor.shape != pt_model_dict[flax_key].shape: raise ValueError( f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected " f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." 
) else: # add weight to pytorch dict flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) # remove from missing keys missing_keys.remove(flax_key) else: # weight is not expected by PyTorch model unexpected_keys.append(flax_key) pt_model.load_state_dict(pt_model_dict) # re-transform missing_keys to list missing_keys = list(missing_keys) if len(unexpected_keys) > 0: logger.warning( "Some weights of the Flax model were not used when initializing the PyTorch model" f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture" " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This" f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect" " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a" " FlaxBertForSequenceClassification model)." ) if len(missing_keys) > 0: logger.warning( f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly" f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to" " use it for predictions and inference." ) return pt_model ================================================ FILE: 4DoF/diffusers/models/modeling_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import itertools import os import re from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch import Tensor, device, nn from .. import __version__ from ..utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, _get_model_file, deprecate, is_accelerate_available, is_safetensors_available, is_torch_version, logging, ) logger = logging.get_logger(__name__) if is_torch_version(">=", "1.9.0"): _LOW_CPU_MEM_USAGE_DEFAULT = True else: _LOW_CPU_MEM_USAGE_DEFAULT = False if is_accelerate_available(): import accelerate from accelerate.utils import set_module_tensor_to_device from accelerate.utils.versions import is_torch_version if is_safetensors_available(): import safetensors def get_parameter_device(parameter: torch.nn.Module): try: parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) return next(parameters_and_buffers).device except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return tuples gen = parameter._named_members(get_members_fn=find_tensor_attributes) first_tuple = next(gen) return first_tuple[1].device def get_parameter_dtype(parameter: torch.nn.Module): try: params = tuple(parameter.parameters()) if len(params) > 0: return params[0].dtype buffers = tuple(parameter.buffers()) if len(buffers) > 0: 
return buffers[0].dtype except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return tuples gen = parameter._named_members(get_members_fn=find_tensor_attributes) first_tuple = next(gen) return first_tuple[1].dtype def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): """ Reads a checkpoint file, returning properly formatted errors if they arise. """ try: if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): return torch.load(checkpoint_file, map_location="cpu") else: return safetensors.torch.load_file(checkpoint_file, device="cpu") except Exception as e: try: with open(checkpoint_file) as f: if f.read().startswith("version"): raise OSError( "You seem to have cloned a repository without having git-lfs installed. Please install " "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " "you cloned." ) else: raise ValueError( f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " "model. Make sure you have saved the model properly." ) from e except (UnicodeDecodeError, ValueError): raise OSError( f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." ) def _load_state_dict_into_model(model_to_load, state_dict): # Convert old format to new format if needed from a PyTorch state_dict # copy state_dict so _load_from_state_dict can modify it state_dict = state_dict.copy() error_msgs = [] # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants # so we need to apply the function recursively. 
def load(module: torch.nn.Module, prefix=""): args = (state_dict, prefix, {}, True, [], [], error_msgs) module._load_from_state_dict(*args) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + ".") load(model_to_load) return error_msgs class ModelMixin(torch.nn.Module): r""" Base class for all models. [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and saving models. - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None def __init__(self): super().__init__() def __getattr__(self, name: str) -> Any: """The only reason we overwrite `getattr` here is to gracefully deprecate accessing config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__': https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module """ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name) is_attribute = name in self.__dict__ if is_in_config and not is_attribute: deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'." 
def set_use_memory_efficient_attention_xformers(
    self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
    """
    Forward `valid` and `attention_op` to every descendant module that defines
    `set_use_memory_efficient_attention_xformers`.

    Descendants are visited depth-first, each module before its own children. The method is not invoked on
    `self`, only on modules below it.
    """
    setter_name = "set_use_memory_efficient_attention_xformers"

    # Explicit stack instead of recursion; children are pushed reversed so they are popped in
    # registration order.
    pending = [child for child in self.children() if isinstance(child, torch.nn.Module)]
    pending.reverse()
    while pending:
        node = pending.pop()
        if hasattr(node, setter_name):
            getattr(node, setter_name)(valid, attention_op)
        # Children are read only after the node itself has been notified.
        pending.extend(reversed(list(node.children())))
""" self.set_use_memory_efficient_attention_xformers(False) def save_pretrained( self, save_directory: Union[str, os.PathLike], is_main_process: bool = True, save_function: Callable = None, safe_serialization: bool = False, variant: Optional[str] = None, ): """ Save a model and its configuration file to a directory so that it can be reloaded using the [`~models.ModelMixin.from_pretrained`] class method. Arguments: save_directory (`str` or `os.PathLike`): Directory to save a model and its configuration file to. Will be created if it doesn't exist. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. 
""" if safe_serialization and not is_safetensors_available(): raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return os.makedirs(save_directory, exist_ok=True) model_to_save = self # Attach architecture to the config # Save the config if is_main_process: model_to_save.save_config(save_directory) # Save the model state_dict = model_to_save.state_dict() weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = _add_variant(weights_name, variant) # Save the model if safe_serialization: safetensors.torch.save_file( state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"} ) else: torch.save(state_dict, os.path.join(save_directory, weights_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a pretrained PyTorch model from a pretrained model configuration. The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To train the model, set it back in training mode with `model.train()`. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved with [`~ModelMixin.save_pretrained`]. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. 
If `"auto"` is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. 
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if `device_map` contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. variant (`str`, *optional*): Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. 
If set to `False`, `safetensors` weights are not loaded. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Example: ```py from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. 
``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False logger.warning( "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" " install accelerate\n```\n." ) if device_map is not None and not is_accelerate_available(): raise NotImplementedError( "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" " `device_map=None`. You can install accelerate with `pip install accelerate`." 
) # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." ) if low_cpu_mem_usage is False and device_map is not None: raise ValueError( f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } # load config config, unused_kwargs, commit_hash = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, device_map=device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) # load model model_file = None if from_flax: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=FLAX_WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) # Convert the weights from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model model = 
load_flax_checkpoint_in_pytorch_model(model, model_file) else: if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) except IOError as e: if not allow_pickle: raise e pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) if low_cpu_mem_usage: # Instantiate model with empty weights with accelerate.init_empty_weights(): model = cls.from_config(config, **unused_kwargs) # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if len(missing_keys) > 0: raise ValueError( f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" " those weights or else make sure your checkpoint file is correct." 
) unexpected_keys = [] empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): accepts_dtype = "dtype" in set( inspect.signature(set_module_tensor_to_device).parameters.keys() ) if param_name not in empty_state_dict: unexpected_keys.append(param_name) continue if empty_state_dict[param_name].shape != param.shape: raise ValueError( f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." ) if accepts_dtype: set_module_tensor_to_device( model, param_name, param_device, value=param, dtype=torch_dtype ) else: set_module_tensor_to_device(model, param_name, param_device, value=param) if cls._keys_to_ignore_on_load_unexpected is not None: for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] if len(unexpected_keys) > 0: logger.warn( f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU try: accelerate.load_checkpoint_and_dispatch( model, model_file, device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, ) except AttributeError as e: # When using accelerate loading, we do not have the ability to load the state # dict and rename the weight names manually. 
Additionally, accelerate skips # torch loading conventions and directly writes into `module.{_buffers, _parameters}` # (which look like they should be private variables?), so we can't use the standard hooks # to rename parameters on load. We need to mimic the original weight names so the correct # attributes are available. After we have loaded the weights, we convert the deprecated # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert # the weights so we don't have to do this again. if "'Attention' object has no attribute" in str(e): logger.warn( f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," " please also re-upload it or open a PR on the original repository." 
) model._temp_convert_self_to_deprecated_attention_blocks() accelerate.load_checkpoint_and_dispatch( model, model_file, device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, ) model._undo_temp_convert_self_to_deprecated_attention_blocks() else: raise e loading_info = { "missing_keys": [], "unexpected_keys": [], "mismatched_keys": [], "error_msgs": [], } else: model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, ) loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """
    Copy `state_dict` into `model` and report which keys did not line up.

    Keys are compared by name between `state_dict` and `model.state_dict()`. When `ignore_mismatched_sizes` is
    true, checkpoint entries whose shape differs from the model's are recorded and removed from `state_dict`
    (the caller's dict is modified in place) before loading. The outcome is logged through `logger`.

    Returns:
        A tuple `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`, where each entry of
        `mismatched_keys` is `(key, checkpoint_shape, model_shape)`.

    Raises:
        RuntimeError: If `_load_state_dict_into_model` reports any error message.
    """
    model_name = model.__class__.__name__
    target_state = model.state_dict()
    checkpoint_keys = list(state_dict.keys())
    model_keys = list(target_state.keys())

    # Retrieve missing & unexpected keys
    missing_keys = list(set(model_keys) - set(checkpoint_keys))
    unexpected_keys = list(set(checkpoint_keys) - set(model_keys))

    # Drop shape-mismatched entries up front so loading does not fail on them
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for key in checkpoint_keys:
            if key not in target_state:
                continue
            found_shape = state_dict[key].shape
            wanted_shape = target_state[key].shape
            if found_shape != wanted_shape:
                mismatched_keys.append((key, found_shape, wanted_shape))
                del state_dict[key]

    error_msgs = _load_state_dict_into_model(model, state_dict)

    if error_msgs:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model_name}:\n\t{error_msg}")

    if unexpected_keys:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model_name}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model_name} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model_name} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model_name}.\n")

    if missing_keys:
        logger.warning(
            f"Some weights of {model_name} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif not mismatched_keys:
        logger.info(
            f"All the weights of {model_name} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model_name} for predictions"
            " without further training."
        )

    if mismatched_keys:
        mismatched_warning = "\n".join(
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        )
        logger.warning(
            f"Some weights of {model_name} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
Example: ```py from diffusers import UNet2DConditionModel model_id = "runwayml/stable-diffusion-v1-5" unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") unet.num_parameters(only_trainable=True) 859520964 ``` """ if exclude_embeddings: embedding_param_names = [ f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, torch.nn.Embedding) ] non_embedding_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) def _convert_deprecated_attention_blocks(self, state_dict): deprecated_attention_block_paths = [] def recursive_find_attn_block(name, module): if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: deprecated_attention_block_paths.append(name) for sub_name, sub_module in module.named_children(): sub_name = sub_name if name == "" else f"{name}.{sub_name}" recursive_find_attn_block(sub_name, sub_module) recursive_find_attn_block("", self) # NOTE: we have to check if the deprecated parameters are in the state dict # because it is possible we are loading from a state dict that was already # converted for path in deprecated_attention_block_paths: # group_norm path stays the same # query -> to_q if f"{path}.query.weight" in state_dict: state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") if f"{path}.query.bias" in state_dict: state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") # key -> to_k if f"{path}.key.weight" in state_dict: state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") if f"{path}.key.bias" in state_dict: state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") # value -> to_v if f"{path}.value.weight" in state_dict: state_dict[f"{path}.to_v.weight"] = 
state_dict.pop(f"{path}.value.weight") if f"{path}.value.bias" in state_dict: state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") # proj_attn -> to_out.0 if f"{path}.proj_attn.weight" in state_dict: state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") if f"{path}.proj_attn.bias" in state_dict: state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") def _temp_convert_self_to_deprecated_attention_blocks(self): deprecated_attention_block_modules = [] def recursive_find_attn_block(module): if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: deprecated_attention_block_modules.append(module) for sub_module in module.children(): recursive_find_attn_block(sub_module) recursive_find_attn_block(self) for module in deprecated_attention_block_modules: module.query = module.to_q module.key = module.to_k module.value = module.to_v module.proj_attn = module.to_out[0] # We don't _have_ to delete the old attributes, but it's helpful to ensure # that _all_ the weights are loaded into the new attributes and we're not # making an incorrect assumption that this model should be converted when # it really shouldn't be. 
del module.to_q del module.to_k del module.to_v del module.to_out def _undo_temp_convert_self_to_deprecated_attention_blocks(self): deprecated_attention_block_modules = [] def recursive_find_attn_block(module): if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: deprecated_attention_block_modules.append(module) for sub_module in module.children(): recursive_find_attn_block(sub_module) recursive_find_attn_block(self) for module in deprecated_attention_block_modules: module.to_q = module.query module.to_k = module.key module.to_v = module.value module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)]) del module.query del module.key del module.value del module.proj_attn ================================================ FILE: 4DoF/diffusers/models/prior_transformer.py ================================================ from dataclasses import dataclass from typing import Dict, Optional, Union import torch import torch.nn.functional as F from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .attention import BasicTransformerBlock from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin @dataclass class PriorTransformerOutput(BaseOutput): """ The output of [`PriorTransformer`]. Args: predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): The predicted CLIP image embedding conditioned on the CLIP text embedding input. """ predicted_image_embedding: torch.FloatTensor class PriorTransformer(ModelMixin, ConfigMixin): """ A Prior Transformer model. Parameters: num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` num_embeddings (`int`, *optional*, defaults to 77): The number of embeddings of the model input `hidden_states` additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + additional_embeddings`. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. time_embed_act_fn (`str`, *optional*, defaults to 'silu'): The activation function to use to create timestep embeddings. norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before passing to Transformer blocks. Set it to `None` if normalization is not needed. embedding_proj_norm_type (`str`, *optional*, defaults to None): The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not needed. encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if `encoder_hidden_states` is `None`. added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot product between the text embedding and image embedding as proposed in the unclip paper https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. If None, will be set to `num_attention_heads * attention_head_dim` embedding_proj_dim (`int`, *optional*, default to None): The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. 
clip_embed_dim (`int`, *optional*, default to None): The dimension of the output. If None, will be set to `embedding_dim`. """ @register_to_config def __init__( self, num_attention_heads: int = 32, attention_head_dim: int = 64, num_layers: int = 20, embedding_dim: int = 768, num_embeddings=77, additional_embeddings=4, dropout: float = 0.0, time_embed_act_fn: str = "silu", norm_in_type: Optional[str] = None, # layer embedding_proj_norm_type: Optional[str] = None, # layer encoder_hid_proj_type: Optional[str] = "linear", # linear added_emb_type: Optional[str] = "prd", # prd time_embed_dim: Optional[int] = None, embedding_proj_dim: Optional[int] = None, clip_embed_dim: Optional[int] = None, ): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim self.additional_embeddings = additional_embeddings time_embed_dim = time_embed_dim or inner_dim embedding_proj_dim = embedding_proj_dim or embedding_dim clip_embed_dim = clip_embed_dim or embedding_dim self.time_proj = Timesteps(inner_dim, True, 0) self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) self.proj_in = nn.Linear(embedding_dim, inner_dim) if embedding_proj_norm_type is None: self.embedding_proj_norm = None elif embedding_proj_norm_type == "layer": self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) else: raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) if encoder_hid_proj_type is None: self.encoder_hidden_states_proj = None elif encoder_hid_proj_type == "linear": self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) else: raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) if added_emb_type == 
"prd": self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) elif added_emb_type is None: self.prd_embedding = None else: raise ValueError( f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." ) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, activation_fn="gelu", attention_bias=True, ) for d in range(num_layers) ] ) if norm_in_type == "layer": self.norm_in = nn.LayerNorm(inner_dim) elif norm_in_type is None: self.norm_in = None else: raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") self.norm_out = nn.LayerNorm(inner_dim) self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) causal_attention_mask = torch.full( [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 ) causal_attention_mask.triu_(1) causal_attention_mask = causal_attention_mask[None, ...] self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def forward( self, hidden_states, timestep: Union[torch.Tensor, float, int], proj_embedding: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, ): """ The [`PriorTransformer`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): The currently predicted image embeddings. timestep (`torch.LongTensor`): Current denoising step. proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): Projected embedding vector the denoising process is conditioned on. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`): Hidden states of the text embeddings the denoising process is conditioned on. attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): Text mask for the text embeddings. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain tuple. 
Returns: [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ batch_size = hidden_states.shape[0] timesteps = timestep if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device) timesteps_projected = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might be fp16, so we need to cast here. timesteps_projected = timesteps_projected.to(dtype=self.dtype) time_embeddings = self.time_embedding(timesteps_projected) if self.embedding_proj_norm is not None: proj_embedding = self.embedding_proj_norm(proj_embedding) proj_embeddings = self.embedding_proj(proj_embedding) if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") hidden_states = self.proj_in(hidden_states) positional_embeddings = self.positional_embedding.to(hidden_states.dtype) additional_embeds = [] additional_embeddings_len = 0 if encoder_hidden_states is not None: additional_embeds.append(encoder_hidden_states) additional_embeddings_len += encoder_hidden_states.shape[1] if len(proj_embeddings.shape) == 2: proj_embeddings = proj_embeddings[:, None, :] if len(hidden_states.shape) == 2: hidden_states = hidden_states[:, None, :] 
additional_embeds = additional_embeds + [ proj_embeddings, time_embeddings[:, None, :], hidden_states, ] if self.prd_embedding is not None: prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) additional_embeds.append(prd_embedding) hidden_states = torch.cat( additional_embeds, dim=1, ) # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 if positional_embeddings.shape[1] < hidden_states.shape[1]: positional_embeddings = F.pad( positional_embeddings, ( 0, 0, additional_embeddings_len, self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, ), value=0.0, ) hidden_states = hidden_states + positional_embeddings if attention_mask is not None: attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0) attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) if self.norm_in is not None: hidden_states = self.norm_in(hidden_states) for block in self.transformer_blocks: hidden_states = block(hidden_states, attention_mask=attention_mask) hidden_states = self.norm_out(hidden_states) if self.prd_embedding is not None: hidden_states = hidden_states[:, -1] else: hidden_states = hidden_states[:, additional_embeddings_len:] predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) if not return_dict: return (predicted_image_embedding,) return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) def post_process_latents(self, prior_latents): prior_latents = (prior_latents * self.clip_std) + self.clip_mean return prior_latents ================================================ FILE: 4DoF/diffusers/models/resnet.py 
class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `"conv"`):
            stored as `self.name`; this layer does not otherwise read it.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        target_channels = out_channels or channels

        self.channels = channels
        self.out_channels = target_channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # The transposed convolution takes precedence when both flags are set.
        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, target_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(channels, target_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, inputs):
        """Upsample `inputs` along the last dimension; dim 1 must equal `self.channels`."""
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            # The transposed convolution is the upsampler; no interpolation is applied.
            return self.conv(inputs)

        doubled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(doubled) if self.use_conv else doubled


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `"conv"`):
            stored as `self.name`; this layer does not otherwise read it.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        stride = 2
        target_channels = out_channels or channels

        self.channels = channels
        self.out_channels = target_channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if not use_conv:
            # Average pooling cannot change the channel count.
            assert channels == target_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
        else:
            self.conv = nn.Conv1d(channels, target_channels, 3, stride=stride, padding=padding)

    def forward(self, inputs):
        """Apply the stride-2 convolution or average pool; dim 1 must equal `self.channels`."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
out_channels (`int`, optional): number of output channels. Defaults to `channels`. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name conv = None if use_conv_transpose: conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.conv = conv else: self.Conv2d_0 = conv def forward(self, hidden_states, output_size=None): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: return self.conv(hidden_states) # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch # https://github.com/pytorch/pytorch/issues/86679 dtype = hidden_states.dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. 
class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `"conv"`):
            when equal to `"conv"`, the layer is additionally registered as `Conv2d_0`
            (so its weights appear under both prefixes in the state dict); for any other
            value it is registered only as `conv`.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        stride = 2

        if use_conv:
            conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # `Conv2d_0` must be registered before `conv` to keep the state-dict key order.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states):
        """Downsample `hidden_states` by a factor of 2; dim 1 must equal `self.channels`."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # With an unpadded convolution, pad one zero column on the right and one zero
            # row at the bottom. Only the last two dims are padded, so the channel count
            # checked above is unaffected.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        return self.conv(hidden_states)
The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 # Setup filter kernel. if kernel is None: kernel = [1] * factor # setup kernel kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: kernel = torch.outer(kernel, kernel) kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) if self.use_conv: convH = weight.shape[2] convW = weight.shape[3] inC = weight.shape[1] pad_value = (kernel.shape[0] - factor) - (convW - 1) stride = (factor, factor) # Determine data dimensions. output_shape = ( (hidden_states.shape[2] - 1) * factor + convH, (hidden_states.shape[3] - 1) * factor + convW, ) output_padding = ( output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 num_groups = hidden_states.shape[1] // inC # Transpose weights. 
class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            # Only the weight and bias of this layer are used by `forward`; the
            # convolution itself is applied through `F.conv2d` in `_downsample_2d`.
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """FIR-filter `hidden_states` and downsample it, optionally through a strided convolution.

        When `self.use_conv` is set, the input is first FIR-filtered and then convolved with
        `weight` at stride `factor`. Otherwise the FIR filter itself downsamples by `factor`.

        Args:
            hidden_states:
                Input tensor passed to `upfirdn2d_native`.
                NOTE(review): presumably `[N, C, H, W]` - confirm against `upfirdn2d_native`.
            weight:
                Convolution weight, required when `self.use_conv` is set. It is passed
                unchanged to `F.conv2d`, so it uses the PyTorch layout
                `[outChannels, inChannels, filterH, filterW]`.
            kernel:
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is
                `[1] * factor`, which corresponds to average pooling.
            factor:
                Integer downsampling factor (default: 2).
            gain:
                Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: the filtered and downsampled tensor.
        """
        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # Build the normalized 2D FIR kernel; a 1D kernel is treated as separable.
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain
        # `kernel` is already a tensor, so move it with `.to` rather than re-wrapping it
        # in `torch.tensor(...)`, which copy-constructs and emits a UserWarning.
        kernel = kernel.to(device=hidden_states.device)

        if self.use_conv:
            conv_w = weight.shape[3]
            # Extra padding so that the unpadded strided convolution below lands on
            # the downsampled grid.
            pad_value = (kernel.shape[0] - factor) + (conv_w - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel,
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        """Downsample `hidden_states` by 2, adding the convolution bias when `use_conv` is set."""
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # `F.conv2d` above ran without a bias, so add it here, broadcast over N, H, W.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
def _expand_to_channel_diagonal(inputs, kernel):
    """Build a `[C, C, kH, kW]` weight holding `kernel` on the channel diagonal.

    `C` is `inputs.shape[1]`. All off-diagonal channel pairs are zero, so a convolution
    with this weight filters every channel independently with the same 2D `kernel`. The
    weight takes its dtype and device from `inputs`.
    """
    num_channels = inputs.shape[1]
    weight = inputs.new_zeros([num_channels, num_channels, *kernel.shape])
    diagonal = torch.arange(num_channels, device=inputs.device)
    weight[diagonal, diagonal] = kernel.to(weight)[None, :].expand(num_channels, -1, -1)
    return weight


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    """Fixed (non-learned) 2x downsampling with a separable `[1, 3, 3, 1] / 8` filter.

    Parameters:
        pad_mode (`str`, default `"reflect"`):
            padding mode passed to `F.pad` before filtering.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = taps.shape[1] // 2 - 1
        # Outer product of the 1D taps gives the 2D filter; it is kept out of the state dict.
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs):
        """Pad `inputs` on all four sides and filter each channel at stride 2."""
        padded = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        weight = _expand_to_channel_diagonal(padded, self.kernel)
        return F.conv2d(padded, weight, stride=2)


class KUpsample2D(nn.Module):
    """Fixed (non-learned) 2x upsampling with a separable `[1, 3, 3, 1] / 8 * 2` filter.

    Parameters:
        pad_mode (`str`, default `"reflect"`):
            padding mode passed to `F.pad` before filtering.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = taps.shape[1] // 2 - 1
        # Outer product of the 1D taps gives the 2D filter; it is kept out of the state dict.
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs):
        """Pad `inputs` on all four sides and apply a per-channel stride-2 transposed convolution."""
        border = (self.pad + 1) // 2
        padded = F.pad(inputs, (border,) * 4, self.pad_mode)
        weight = _expand_to_channel_diagonal(padded, self.kernel)
        return F.conv_transpose2d(padded, weight, stride=2, padding=self.pad * 2 + 1)
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. groups_out (`int`, *optional*, default to None): The number of groups to use for the second normalization layer. if set to None, same as `groups`. eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or "ada_group" for a stronger conditioning with scale and shift. kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. use_in_shortcut (`bool`, *optional*, default to `True`): If `True`, add a 1x1 nn.conv2d layer for skip-connection. up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the `conv_shortcut` output. conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. If None, same as `out_channels`. 
""" def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, groups_out=None, pre_norm=True, eps=1e-6, non_linearity="swish", skip_time_act=False, time_embedding_norm="default", # default, scale_shift, ada_group, spatial kernel=None, output_scale_factor=1.0, use_in_shortcut=None, up=False, down=False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, ): super().__init__() self.pre_norm = pre_norm self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.up = up self.down = down self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act if groups_out is None: groups_out = groups if self.time_embedding_norm == "ada_group": self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) elif self.time_embedding_norm == "spatial": self.norm1 = SpatialNorm(in_channels, temb_channels) else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") else: self.time_emb_proj = None if self.time_embedding_norm == "ada_group": self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) elif self.time_embedding_norm == "spatial": self.norm2 = SpatialNorm(out_channels, temb_channels) 
else: self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) self.upsample = self.downsample = None if self.up: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") else: self.upsample = Upsample2D(in_channels, use_conv=False) elif self.down: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) else: self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = torch.nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) def forward(self, input_tensor, temb): hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor) hidden_states = self.upsample(hidden_states) elif self.downsample is not None: input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) hidden_states = self.conv1(hidden_states) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) temb = self.time_emb_proj(temb)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) else: hidden_states = self.norm2(hidden_states) if temb is not None and self.time_embedding_norm == "scale_shift": scale, shift = torch.chunk(temb, 2, dim=1) hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor return output_tensor # unet_rl.py def rearrange_dims(tensor): if len(tensor.shape) == 2: return tensor[:, :, None] if len(tensor.shape) == 3: return tensor[:, :, None, :] elif len(tensor.shape) == 4: return tensor[:, :, 0, :] else: raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish """ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) self.group_norm = nn.GroupNorm(n_groups, out_channels) self.mish = nn.Mish() def forward(self, inputs): intermediate_repr = 
self.conv1d(inputs) intermediate_repr = rearrange_dims(intermediate_repr) intermediate_repr = self.group_norm(intermediate_repr) intermediate_repr = rearrange_dims(intermediate_repr) output = self.mish(intermediate_repr) return output # unet_rl.py class ResidualTemporalBlock1D(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): super().__init__() self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) self.time_emb_act = nn.Mish() self.time_emb = nn.Linear(embed_dim, out_channels) self.residual_conv = ( nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() ) def forward(self, inputs, t): """ Args: inputs : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x out_channels x horizon ] """ t = self.time_emb_act(t) t = self.time_emb(t) out = self.conv_in(inputs) + rearrange_dims(t) out = self.conv_out(out) return out + self.residual_conv(inputs) def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a: multiple of the upsampling factor. Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
Returns: output: Tensor of the shape `[N, C, H * factor, W * factor]` """ assert isinstance(factor, int) and factor >= 1 if kernel is None: kernel = [1] * factor kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: kernel = torch.outer(kernel, kernel) kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) pad_value = kernel.shape[0] - factor output = upfirdn2d_native( hidden_states, kernel.to(device=hidden_states.device), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), ) return output def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Downsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the downsampling factor. Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). 
Returns: output: Tensor of the shape `[N, C, H // factor, W // factor]` """ assert isinstance(factor, int) and factor >= 1 if kernel is None: kernel = [1] * factor kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: kernel = torch.outer(kernel, kernel) kernel /= torch.sum(kernel) kernel = kernel * gain pad_value = kernel.shape[0] - factor output = upfirdn2d_native( hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) ) return output def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): up_x = up_y = up down_x = down_y = down pad_x0 = pad_y0 = pad[0] pad_x1 = pad_y1 = pad[1] _, channel, in_h, in_w = tensor.shape tensor = tensor.reshape(-1, in_h, in_w, 1) _, in_h, in_w, minor = tensor.shape kernel_h, kernel_w = kernel.shape out = tensor.view(-1, in_h, 1, in_w, 1, minor) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) out = out.to(tensor.device) # Move back to mps if necessary out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :, ] out = out.permute(0, 3, 1, 2) out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( -1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) out = out.permute(0, 2, 3, 1) out = out[:, ::down_y, ::down_x, :] out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.view(-1, channel, out_h, out_w) class TemporalConvLayer(nn.Module): """ Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: 
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 """ def __init__(self, in_dim, out_dim=None, dropout=0.0): super().__init__() out_dim = out_dim or in_dim self.in_dim = in_dim self.out_dim = out_dim # conv layers self.conv1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) ) self.conv2 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv3 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv4 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) # zero out the last layer params,so the conv block is identity nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) def forward(self, hidden_states, num_frames=1): hidden_states = ( hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) ) identity = hidden_states hidden_states = self.conv1(hidden_states) hidden_states = self.conv2(hidden_states) hidden_states = self.conv3(hidden_states) hidden_states = self.conv4(hidden_states) hidden_states = identity + hidden_states hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] ) return hidden_states ================================================ FILE: 4DoF/diffusers/models/resnet_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import flax.linen as nn import jax import jax.numpy as jnp class FlaxUpsample2D(nn.Module): out_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, hidden_states): batch, height, width, channels = hidden_states.shape hidden_states = jax.image.resize( hidden_states, shape=(batch, height * 2, width * 2, channels), method="nearest", ) hidden_states = self.conv(hidden_states) return hidden_states class FlaxDownsample2D(nn.Module): out_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)), # padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states): # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim # hidden_states = jnp.pad(hidden_states, pad_width=pad) hidden_states = self.conv(hidden_states) return hidden_states class FlaxResnetBlock2D(nn.Module): in_channels: int out_channels: int = None dropout_prob: float = 0.0 use_nin_shortcut: bool = None dtype: jnp.dtype = jnp.float32 def setup(self): out_channels = self.in_channels if self.out_channels is None else self.out_channels self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) self.conv1 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) self.dropout = 
nn.Dropout(self.dropout_prob) self.conv2 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut self.conv_shortcut = None if use_nin_shortcut: self.conv_shortcut = nn.Conv( out_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states, temb, deterministic=True): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) temb = self.time_emb_proj(nn.swish(temb)) temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) hidden_states = hidden_states + temb hidden_states = self.norm2(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.dropout(hidden_states, deterministic) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: residual = self.conv_shortcut(residual) return hidden_states + residual ================================================ FILE: 4DoF/diffusers/models/t5_film_transformer.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import math import torch from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from .attention_processor import Attention from .embeddings import get_timestep_embedding from .modeling_utils import ModelMixin class T5FilmDecoder(ModelMixin, ConfigMixin): @register_to_config def __init__( self, input_dims: int = 128, targets_length: int = 256, max_decoder_noise_time: float = 2000.0, d_model: int = 768, num_layers: int = 12, num_heads: int = 12, d_kv: int = 64, d_ff: int = 2048, dropout_rate: float = 0.1, ): super().__init__() self.conditioning_emb = nn.Sequential( nn.Linear(d_model, d_model * 4, bias=False), nn.SiLU(), nn.Linear(d_model * 4, d_model * 4, bias=False), nn.SiLU(), ) self.position_encoding = nn.Embedding(targets_length, d_model) self.position_encoding.weight.requires_grad = False self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) self.dropout = nn.Dropout(p=dropout_rate) self.decoders = nn.ModuleList() for lyr_num in range(num_layers): # FiLM conditional T5 decoder lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) self.decoders.append(lyr) self.decoder_norm = T5LayerNorm(d_model) self.post_dropout = nn.Dropout(p=dropout_rate) self.spec_out = nn.Linear(d_model, input_dims, bias=False) def encoder_decoder_mask(self, query_input, key_input): mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) return mask.unsqueeze(-3) def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): batch, _, _ = decoder_input_tokens.shape assert decoder_noise_time.shape == (batch,) # decoder_noise_time is in [0, 1), so rescale to expected timing range. 
time_steps = get_timestep_embedding( decoder_noise_time * self.config.max_decoder_noise_time, embedding_dim=self.config.d_model, max_period=self.config.max_decoder_noise_time, ).to(dtype=self.dtype) conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) seq_length = decoder_input_tokens.shape[1] # If we want to use relative positions for audio context, we can just offset # this sequence by the length of encodings_and_masks. decoder_positions = torch.broadcast_to( torch.arange(seq_length, device=decoder_input_tokens.device), (batch, seq_length), ) position_encodings = self.position_encoding(decoder_positions) inputs = self.continuous_inputs_projection(decoder_input_tokens) inputs += position_encodings y = self.dropout(inputs) # decoder: No padding present. decoder_mask = torch.ones( decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype ) # Translate encoding masks to encoder-decoder masks. 
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] # cross attend style: concat encodings encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) for lyr in self.decoders: y = lyr( y, conditioning_emb=conditioning_emb, encoder_hidden_states=encoded, encoder_attention_mask=encoder_decoder_mask, )[0] y = self.decoder_norm(y) y = self.post_dropout(y) spec_out = self.spec_out(y) return spec_out class DecoderLayer(nn.Module): def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6): super().__init__() self.layer = nn.ModuleList() # cond self attention: layer 0 self.layer.append( T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) ) # cross attention: layer 1 self.layer.append( T5LayerCrossAttention( d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon, ) ) # Film Cond MLP + dropout: last layer self.layer.append( T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) ) def forward( self, hidden_states, conditioning_emb=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None, ): hidden_states = self.layer[0]( hidden_states, conditioning_emb=conditioning_emb, attention_mask=attention_mask, ) if encoder_hidden_states is not None: encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( encoder_hidden_states.dtype ) hidden_states = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_extended_attention_mask, ) # Apply Film Conditional Feed Forward layer hidden_states = self.layer[-1](hidden_states, conditioning_emb) return (hidden_states,) class T5LayerSelfAttentionCond(nn.Module): def 
__init__(self, d_model, d_kv, num_heads, dropout_rate): super().__init__() self.layer_norm = T5LayerNorm(d_model) self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) self.dropout = nn.Dropout(dropout_rate) def forward( self, hidden_states, conditioning_emb=None, attention_mask=None, ): # pre_self_attention_layer_norm normed_hidden_states = self.layer_norm(hidden_states) if conditioning_emb is not None: normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) # Self-attention block attention_output = self.attention(normed_hidden_states) hidden_states = hidden_states + self.dropout(attention_output) return hidden_states class T5LayerCrossAttention(nn.Module): def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): super().__init__() self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) self.dropout = nn.Dropout(dropout_rate) def forward( self, hidden_states, key_value_states=None, attention_mask=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( normed_hidden_states, encoder_hidden_states=key_value_states, attention_mask=attention_mask.squeeze(1), ) layer_output = hidden_states + self.dropout(attention_output) return layer_output class T5LayerFFCond(nn.Module): def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon): super().__init__() self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) self.dropout = nn.Dropout(dropout_rate) def forward(self, hidden_states, conditioning_emb=None): forwarded_states = self.layer_norm(hidden_states) if 
conditioning_emb is not None: forwarded_states = self.film(forwarded_states, conditioning_emb) forwarded_states = self.DenseReluDense(forwarded_states) hidden_states = hidden_states + self.dropout(forwarded_states) return hidden_states class T5DenseGatedActDense(nn.Module): def __init__(self, d_model, d_ff, dropout_rate): super().__init__() self.wi_0 = nn.Linear(d_model, d_ff, bias=False) self.wi_1 = nn.Linear(d_model, d_ff, bias=False) self.wo = nn.Linear(d_ff, d_model, bias=False) self.dropout = nn.Dropout(dropout_rate) self.act = NewGELUActivation() def forward(self, hidden_states): hidden_gelu = self.act(self.wi_0(hidden_states)) hidden_linear = self.wi_1(hidden_states) hidden_states = hidden_gelu * hidden_linear hidden_states = self.dropout(hidden_states) hidden_states = self.wo(hidden_states) return hidden_states class T5LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for # half-precision inputs is done in fp32 variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states class NewGELUActivation(nn.Module): """ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 """ def forward(self, input: torch.Tensor) -> torch.Tensor: return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) class T5FiLMLayer(nn.Module): """ FiLM Layer """ def __init__(self, in_features, out_features): super().__init__() self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) def forward(self, x, conditioning_emb): emb = self.scale_bias(conditioning_emb) scale, shift = torch.chunk(emb, 2, -1) x = x * (1 + scale) + shift return x ================================================ FILE: 4DoF/diffusers/models/transformer_2d.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, Optional import torch import torch.nn.functional as F from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed from .modeling_utils import ModelMixin @dataclass class Transformer2DModelOutput(BaseOutput): """ The output of [`Transformer2DModel`]. 
Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. """ sample: torch.FloatTensor class Transformer2DModel(ModelMixin, ConfigMixin): """ A 2D Transformer model for image-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. num_vector_embeds (`int`, *optional*): The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). Includes the class for the masked latent pixel. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. num_embeds_ada_norm ( `int`, *optional*): The number of diffusion steps used during training. Pass if at least one of the norm_layers is `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 
attention_bias (`bool`, *optional*): Configure if the `TransformerBlocks` attention should contain a bias parameter. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", norm_elementwise_affine: bool = True, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration self.is_input_continuous = (in_channels is not None) and (patch_size is None) self.is_input_vectorized = num_vector_embeds is not None self.is_input_patches = in_channels is not None and patch_size is not None if norm_type == "layer_norm" and num_embeds_ada_norm is not None: deprecation_message = ( f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" " results in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it" " would be very nice if you could open a Pull request for the `transformer/config.json` file" ) deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) norm_type = "ada_norm" if self.is_input_continuous and self.is_input_vectorized: raise ValueError( f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" " sure that either `in_channels` or `num_vector_embeds` is None." ) elif self.is_input_vectorized and self.is_input_patches: raise ValueError( f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" " sure that either `num_vector_embeds` or `num_patches` is None." ) elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: raise ValueError( f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." ) # 2. 
Define input layers if self.is_input_continuous: self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: self.proj_in = nn.Linear(in_channels, inner_dim) else: self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" self.height = sample_size self.width = sample_size self.num_vector_embeds = num_vector_embeds self.num_latent_pixels = self.height * self.width self.latent_image_embedding = ImagePositionalEmbeddings( num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width ) elif self.is_input_patches: assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" self.height = sample_size self.width = sample_size self.patch_size = patch_size self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, ) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, norm_elementwise_affine=norm_elementwise_affine, ) for d in range(num_layers) ] ) # 4. 
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        posemb: Optional = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Embeds the input according to the configured input mode (continuous / vectorized / patches), runs it
        through every block in `self.transformer_blocks`, and projects it back with the matching output head.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            posemb (*optional*):
                Forwarded unchanged to every transformer block as `posemb=`. Its structure is not visible from this
                method -- presumably a pose/positional embedding consumed by the attention layers; confirm against
                `BasicTransformerBlock`.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                Forwarded unchanged to every transformer block.
            attention_mask ( `torch.Tensor`, *optional*):
                Self-attention mask. A 2-D tensor is treated as a keep/discard mask (1 = keep, 0 = discard) and
                converted here into an additive bias (0 = keep, -10000 = discard) with a singleton query dimension.
                Tensors of any other rank are passed through untouched.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            # kept for the skip connection added back in step 3
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                # conv projection runs on the image layout; flatten the spatial dims to tokens afterwards
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # linear projection runs on tokens, so flatten first. Note that `inner_dim` is read *before*
                # `proj_in` here, i.e. it is the input channel count -- which is also the width `proj_out`
                # produces in the linear branch of step 3.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                posemb=posemb,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                # mirror of step 1: restore the image layout, then apply the conv projection
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                # mirror of step 1: project the tokens, then restore the image layout
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # reuses the conditioning embedder of the first block's `norm1` to derive the final shift/scale
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify
            # the token count is treated as a square grid: height == width == sqrt(num_tokens)
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional import torch from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .attention import BasicTransformerBlock from .modeling_utils import ModelMixin @dataclass class TransformerTemporalModelOutput(BaseOutput): """ The output of [`TransformerTemporalModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. """ sample: torch.FloatTensor class TransformerTemporalModel(ModelMixin, ConfigMixin): """ A Transformer model for video-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). This is fixed during training since it is used to learn a number of position embeddings. 
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. attention_bias (`bool`, *optional*): Configure if the `TransformerBlock` attention should contain a bias parameter. double_self_attention (`bool`, *optional*): Configure if each `TransformerBlock` should contain two self-attention layers. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, activation_fn: str = "geglu", norm_elementwise_affine: bool = True, double_self_attention: bool = True, ): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Linear(in_channels, inner_dim) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, attention_bias=attention_bias, double_self_attention=double_self_attention, norm_elementwise_affine=norm_elementwise_affine, ) for d in range(num_layers) ] ) self.proj_out = nn.Linear(inner_dim, in_channels) def forward( self, hidden_states, encoder_hidden_states=None, timestep=None, class_labels=None, num_frames=1, cross_attention_kwargs=None, return_dict: bool = True, ): """ The [`TransformerTemporal`] forward method. 
Args: hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): Input hidden_states. encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in `AdaLayerZeroNorm`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ # 1. Input batch_frames, channel, height, width = hidden_states.shape batch_size = batch_frames // num_frames residual = hidden_states hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) hidden_states = hidden_states.permute(0, 2, 1, 3, 4) hidden_states = self.norm(hidden_states) hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) hidden_states = self.proj_in(hidden_states) # 2. Blocks for block in self.transformer_blocks: hidden_states = block( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, ) # 3. 
class UNet1DModel(ModelMixin, ConfigMixin):
    r"""
    A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int`, *optional*, defaults to 65536): Default length of sample. Should be adaptable at runtime.
        sample_rate (`int`, *optional*): Accepted for configuration purposes; not read by the constructor.
        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
        extra_in_channels (`int`, *optional*, defaults to 0):
            Number of additional channels to be added to the input of the first down block. Useful for cases where the
            input data has more channels than what the model was initially designed for.
        time_embedding_type (`str`, *optional*, defaults to `"fourier"`):
            Type of time embedding to use. Must be `"fourier"` or `"positional"`.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos in the time embedding.
        use_timestep_embedding (`bool`, *optional*, defaults to `False`):
            Whether to pass the projected timestep through a `TimestepEmbedding` MLP. When `False` the projected
            timestep is instead tiled along the sample length in `forward`.
        freq_shift (`float`, *optional*, defaults to 0.0):
            Frequency shift passed to the `"positional"` time embedding (unused for `"fourier"`).
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
            Tuple of upsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
        out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
        norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
        layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
        downsample_each_block (`bool`, *optional*, defaults to `False`):
            Experimental feature for using a UNet without upsampling.

    Raises:
        ValueError: If `time_embedding_type` is neither `"fourier"` nor `"positional"`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: int = 65536,
        sample_rate: Optional[int] = None,
        in_channels: int = 2,
        out_channels: int = 2,
        extra_in_channels: int = 0,
        time_embedding_type: str = "fourier",
        flip_sin_to_cos: bool = True,
        use_timestep_embedding: bool = False,
        freq_shift: float = 0.0,
        down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
        mid_block_type: str = "UNetMidBlock1D",
        out_block_type: Optional[str] = None,
        block_out_channels: Tuple[int, ...] = (32, 32, 64),
        act_fn: Optional[str] = None,
        norm_num_groups: int = 8,
        layers_per_block: int = 1,
        downsample_each_block: bool = False,
    ):
        super().__init__()
        self.sample_size = sample_size

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(
                block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
            )
            timestep_input_dim = block_out_channels[0]
        else:
            # Fail here with a clear message instead of later with a missing `time_proj` attribute or an
            # unbound `timestep_input_dim`.
            raise ValueError(
                f"Unsupported time_embedding_type {time_embedding_type!r}; expected 'fourier' or 'positional'."
            )

        if use_timestep_embedding:
            time_embed_dim = block_out_channels[0] * 4
            self.time_mlp = TimestepEmbedding(
                in_channels=timestep_input_dim,
                time_embed_dim=time_embed_dim,
                act_fn=act_fn,
                out_dim=block_out_channels[0],
            )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])
        self.out_block = None

        # down
        output_channel = in_channels
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            # only the very first block sees the extra (e.g. time-embedding) input channels
            if i == 0:
                input_channel += extra_in_channels

            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_downsample=not is_final_block or downsample_each_block,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = get_mid_block(
            mid_block_type,
            in_channels=block_out_channels[-1],
            mid_channels=block_out_channels[-1],
            out_channels=block_out_channels[-1],
            embed_dim=block_out_channels[0],
            num_layers=layers_per_block,
            add_downsample=downsample_each_block,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        # without a dedicated output block, the last up block has to emit `out_channels` itself
        if out_block_type is None:
            final_upsample_channels = out_channels
        else:
            final_upsample_channels = block_out_channels[0]

        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = (
                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
            )

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                temb_channels=block_out_channels[0],
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.out_block = get_out_block(
            out_block_type=out_block_type,
            num_groups_out=num_groups_out,
            embed_dim=block_out_channels[0],
            out_channels=out_channels,
            act_fn=act_fn,
            fc_dim=block_out_channels[-1] // 4,
        )

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        return_dict: bool = True,
    ) -> Union[UNet1DOutput, Tuple]:
        r"""
        The [`UNet1DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # NOTE(review): a Python float is truncated by the `torch.long` cast; pass a tensor to keep
            # fractional timesteps.
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            # promote a 0-d tensor to shape (1,) on the sample's device
            timesteps = timesteps[None].to(sample.device)

        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
            timestep_embed = self.time_mlp(timestep_embed)
        else:
            # tile the embedding along the sample length, then broadcast it over the batch dimension
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

        # 2. down
        down_block_res_samples = ()
        for downsample_block in self.down_blocks:
            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
            down_block_res_samples += res_samples

        # 3. mid
        if self.mid_block:
            sample = self.mid_block(sample, timestep_embed)

        # 4. up -- each up block consumes the most recent remaining skip tensor
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-1:]
            down_block_res_samples = down_block_res_samples[:-1]
            sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)

        # 5. post-process
        if self.out_block:
            sample = self.out_block(sample, timestep_embed)

        if not return_dict:
            return (sample,)

        return UNet1DOutput(sample=sample)
class DownResnetBlock1D(nn.Module):
    """
    Down-sampling stage made of `num_layers + 1` `ResidualTemporalBlock1D` layers.

    `forward` returns the (optionally activated and downsampled) hidden states together with a one-element
    tuple holding the resnet-stack output, which serves as the skip connection.

    `conv_shortcut`, `time_embedding_norm` and `output_scale_factor` are only stored on the instance;
    `groups` and `groups_out` are accepted but do not influence the layers built here.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        # One resnet always adapts in_channels -> out_channels; `num_layers` more keep the width.
        channel_plan = [(in_channels, out_channels)] + [(out_channels, out_channels)] * num_layers
        self.resnets = nn.ModuleList(
            [ResidualTemporalBlock1D(c_in, c_out, embed_dim=temb_channels) for c_in, c_out in channel_plan]
        )

        self.nonlinearity = get_activation(non_linearity) if non_linearity is not None else None
        self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) if add_downsample else None

    def forward(self, hidden_states, temb=None):
        for layer in self.resnets:
            hidden_states = layer(hidden_states, temb)

        # The skip tensor is captured before the activation and the downsampling are applied.
        skip_states = (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, skip_states
class MidResTemporalBlock1D(nn.Module):
    """
    Middle block made of `num_layers + 1` `ResidualTemporalBlock1D` layers plus an optional resampling stage.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count produced by every resnet in the stack.
        embed_dim: Width of the time embedding handed to each resnet.
        num_layers: Number of resnets added on top of the mandatory first one.
        add_downsample: Append a `Downsample1D` after the resnets.
        add_upsample: Append an `Upsample1D` after the resnets.
        non_linearity: Optional activation name resolved through `get_activation`. The resulting module is
            stored as `self.nonlinearity` but is not applied in `forward`.

    Raises:
        ValueError: If both `add_downsample` and `add_upsample` are requested.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        # Reject the contradictory configuration before building any layer.
        if add_upsample and add_downsample:
            raise ValueError("Block cannot downsample and upsample")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample
        self.add_upsample = add_upsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
        self.resnets = nn.ModuleList(resnets)

        self.nonlinearity = None if non_linearity is None else get_activation(non_linearity)

        # Fixed: the upsampling stage used to be built from `Downsample1D`.
        self.upsample = Upsample1D(out_channels, use_conv=True) if add_upsample else None
        self.downsample = Downsample1D(out_channels, use_conv=True) if add_downsample else None

    def forward(self, hidden_states, temb):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)
        if self.downsample is not None:
            # Fixed: the result used to be assigned to `self.downsample`, which overwrote the layer
            # (rejected by `nn.Module.__setattr__`) and dropped the downsampled activations.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
class OutValueFunctionBlock(nn.Module):
    """
    Final head: Linear -> Mish -> Linear applied to the flattened features concatenated with `temb`.

    The first linear layer expects `fc_dim + embed_dim` inputs and the last one emits a single value per
    batch row.
    """

    def __init__(self, fc_dim, embed_dim):
        super().__init__()
        hidden_width = fc_dim // 2

        head = nn.ModuleList()
        head.append(nn.Linear(fc_dim + embed_dim, hidden_width))
        head.append(nn.Mish())
        head.append(nn.Linear(hidden_width, 1))
        self.final_block = head

    def forward(self, hidden_states, temb):
        # Collapse everything but the batch axis, then append the embedding on the feature axis.
        batch_size = hidden_states.shape[0]
        features = torch.cat((hidden_states.view(batch_size, -1), temb), dim=-1)

        for stage in self.final_block:
            features = stage(features)
        return features
class Upsample1d(nn.Module):
    """
    Fixed-kernel (non-learned) 2x upsampler implemented as a per-channel transposed convolution.

    The filter taps come from the module-level `_kernels` table, scaled by 2, and are stored as the
    `kernel` buffer.
    """

    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor(_kernels[kernel]) * 2
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer("kernel", taps)

    def forward(self, hidden_states, temb=None):
        # `temb` is accepted for interface compatibility and is not used.
        edge = (self.pad + 1) // 2
        padded = F.pad(hidden_states, (edge, edge), self.pad_mode)

        num_channels = padded.shape[1]
        num_taps = self.kernel.shape[0]

        # Put the filter on the channel diagonal only, so channels are never mixed.
        weight = padded.new_zeros([num_channels, num_channels, num_taps])
        diagonal = torch.arange(num_channels, device=padded.device)
        weight[diagonal, diagonal] = self.kernel.to(weight)[None, :].expand(num_channels, -1)

        return F.conv_transpose1d(padded, weight, stride=2, padding=self.pad * 2 + 1)
class ResConvBlock(nn.Module):
    """
    Residual block of two width-5 1D convolutions with GroupNorm + GELU after the first one.

    A bias-free 1x1 convolution is used on the skip path only when `in_channels != out_channels`.
    When `is_last` is true the second GroupNorm + GELU pair is neither created nor applied.
    """

    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.has_conv_skip = in_channels != out_channels

        # Layers are created in the same order as they are registered, skip path first.
        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
        self.gelu_1 = nn.GELU()
        self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)

        if not self.is_last:
            self.group_norm_2 = nn.GroupNorm(1, out_channels)
            self.gelu_2 = nn.GELU()

    def forward(self, hidden_states):
        shortcut = hidden_states
        if self.has_conv_skip:
            shortcut = self.conv_skip(hidden_states)

        features = self.gelu_1(self.group_norm_1(self.conv_1(hidden_states)))
        features = self.conv_2(features)

        if self.is_last:
            return features + shortcut

        features = self.gelu_2(self.group_norm_2(features))
        return features + shortcut
class AttnDownBlock1D(nn.Module):
    """Down block: a cubic `Downsample1d` followed by three (`ResConvBlock`, `SelfAttention1d`) pairs.

    Args:
        out_channels: Channels produced by the last pair.
        in_channels: Channels of the incoming tensor.
        mid_channels: Width used inside the block; falls back to `out_channels` when `None`.
    """

    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        self.down = Downsample1d("cubic")

        # (input width, output width) of each pair; every ResConvBlock uses `mid_channels` internally.
        stage_widths = [
            (in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        conv_stages = [ResConvBlock(src, mid_channels, dst) for src, dst in stage_widths]
        attn_stages = [SelfAttention1d(dst, dst // 32) for _, dst in stage_widths]

        self.attentions = nn.ModuleList(attn_stages)
        self.resnets = nn.ModuleList(conv_stages)

    def forward(self, hidden_states, temb=None):
        """Return `(output, (output,))`; the one-element tuple is the skip state. `temb` is unused."""
        features = self.down(hidden_states)

        for conv_stage, attn_stage in zip(self.resnets, self.attentions):
            features = attn_stage(conv_stage(features))

        return features, (features,)
class UpBlock1D(nn.Module):
    """Up block: concatenate the skip state, apply three `ResConvBlock`s, then a cubic `Upsample1d`.

    Args:
        in_channels: Channels of the incoming tensor; the first block sees `2 * in_channels` after the
            skip state is concatenated.
        out_channels: Channels produced by the last residual block.
        mid_channels: Width used inside the block; falls back to `in_channels` when `None`.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = in_channels

        self.resnets = nn.ModuleList(
            [
                ResConvBlock(2 * in_channels, mid_channels, mid_channels),
                ResConvBlock(mid_channels, mid_channels, mid_channels),
                ResConvBlock(mid_channels, mid_channels, out_channels),
            ]
        )
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        """Fuse the last entry of `res_hidden_states_tuple` along dim 1 and upsample. `temb` is unused."""
        skip_state = res_hidden_states_tuple[-1]
        features = torch.cat([hidden_states, skip_state], dim=1)

        for conv_stage in self.resnets:
            features = conv_stage(features)

        return self.up(features)
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    """Instantiate the 1D up block named by `up_block_type`.

    Only `UpResnetBlock1D` receives `num_layers`, `temb_channels` and `add_upsample`; the other block
    types are built from `in_channels` and `out_channels` alone.

    Raises:
        ValueError: If `up_block_type` is not one of the supported names.
    """
    if up_block_type == "UpResnetBlock1D":
        resnet_kwargs = {
            "in_channels": in_channels,
            "num_layers": num_layers,
            "out_channels": out_channels,
            "temb_channels": temb_channels,
            "add_upsample": add_upsample,
        }
        return UpResnetBlock1D(**resnet_kwargs)

    channel_kwargs = {"in_channels": in_channels, "out_channels": out_channels}
    if up_block_type == "UpBlock1D":
        return UpBlock1D(**channel_kwargs)
    if up_block_type == "AttnUpBlock1D":
        return AttnUpBlock1D(**channel_kwargs)
    if up_block_type == "UpBlock1DNoSkip":
        return UpBlock1DNoSkip(**channel_kwargs)

    raise ValueError(f"{up_block_type} does not exist.")
class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    The middle block is always a [`UNetMidBlock2D`]; it is not configurable through the constructor.

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`):
            Type of time embedding to use. Must be `"positional"` or `"fourier"`.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet"
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`):
            The attention head dimension. When `None`, each block uses its own output channel count instead.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the `add_attention` argument of the mid block.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.

    Raises:
        ValueError: If the number of down blocks, up blocks and `block_out_channels` entries disagree, or if
            `time_embedding_type` is not supported.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            # Without this branch `timestep_input_dim` stays unbound and the constructor fails below
            # with an unrelated-looking UnboundLocalError.
            raise ValueError(
                f"Unsupported `time_embedding_type`: {time_embedding_type}. Choose from `positional` or `fourier`."
            )

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If the model was built with a class embedding and `class_labels` is `None`.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            # Blocks exposing `skip_conv` additionally take and return the running skip sample.
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            # Each up block consumes one stored residual per resnet, taken from the end of the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            # Reshape to (batch, 1, 1, ...) so the division broadcasts over every non-batch dimension.
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
from typing import Any, Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn from ..utils import is_torch_version, logging from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_down_block( down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_eps, resnet_act_fn, transformer_layers_per_block=1, num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, downsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warn( f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
) attention_head_dim = num_attention_heads down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": return DownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "ResnetDownsampleBlock2D": return ResnetDownsampleBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": if add_downsample is False: downsample_type = None else: downsample_type = downsample_type or "conv" # default to 'conv' return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, 
cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") return SimpleCrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnSkipDownBlock2D": return AttnSkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, 
downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownEncoderBlock2D": return AttnDownEncoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "KDownBlock2D": return KDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) elif down_block_type == "KCrossAttnDownBlock2D": return KCrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, add_self_attention=True if not add_downsample else False, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_eps, resnet_act_fn, transformer_layers_per_block=1, num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, upsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warn( f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
) attention_head_dim = num_attention_heads up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": return UpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "ResnetUpsampleBlock2D": return ResnetUpsampleBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") return SimpleCrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, 
out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": if add_upsample is False: upsample_type = None else: upsample_type = upsample_type or "conv" # default to 'conv' return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnSkipUpBlock2D": return AttnSkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, 
class UNetMidBlock2D(nn.Module):
    """Middle UNet block: one leading ResNet followed by ``num_layers`` (attention, ResNet) pairs.

    Every sub-layer is built with ``in_channels`` as both its input and output width. When
    ``add_attention`` is false, the attention slot of each pair holds ``None`` and is skipped
    in ``forward``.

    Args:
        in_channels: Channel width used for every sub-layer.
        temb_channels: Time-embedding width passed to each ``ResnetBlock2D``; also used as the
            attention ``spatial_norm_dim`` when ``resnet_time_scale_shift == "spatial"``.
        dropout: Dropout value forwarded to each ``ResnetBlock2D``.
        num_layers: Number of (attention, ResNet) pairs after the first ResNet.
        resnet_eps: Epsilon forwarded to the ResNet and attention layers.
        resnet_time_scale_shift: Forwarded as ``time_embedding_norm``. ``"default"`` gives the
            attention layers ``norm_num_groups``; ``"spatial"`` gives them ``spatial_norm_dim``.
        resnet_act_fn: Forwarded as ``non_linearity``.
        resnet_groups: Normalisation group count; ``None`` falls back to ``min(in_channels // 4, 32)``.
        resnet_pre_norm: Forwarded as ``pre_norm``.
        add_attention: Whether attention layers are created at all.
        attention_head_dim: Channels per attention head; ``None`` falls back to ``in_channels``
            (a warning is logged).
        output_scale_factor: Forwarded to the ResNets and, as ``rescale_output_factor``, to attention.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
            )
            attention_head_dim = in_channels

        for _ in range(num_layers):
            if self.add_attention:
                attentions.append(
                    Attention(
                        in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        rescale_output_factor=output_scale_factor,
                        eps=resnet_eps,
                        norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None,
                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                        residual_connection=True,
                        bias=True,
                        upcast_softmax=True,
                        _from_deprecated_attn_block=True,
                    )
                )
            else:
                # Keep the list aligned with `resnets[1:]` so `forward` can zip them.
                attentions.append(None)

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        """Run the first ResNet, then each (attention, ResNet) pair; ``None`` attention slots are skipped."""
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if attn is not None:
                hidden_states = attn(hidden_states, temb=temb)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class UNetMidBlock2DCrossAttn(nn.Module):
    """Middle UNet block interleaving ResNets with cross-attention transformer layers.

    The block holds one leading ``ResnetBlock2D`` and then ``num_layers`` pairs of
    (transformer, ``ResnetBlock2D``). ``dual_cross_attention`` selects ``DualTransformer2DModel``
    instead of ``Transformer2DModel`` for the transformer slot. All sub-layers are built with
    ``in_channels`` as their channel width.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        # Every ResNet in this block is configured identically.
        resnet_config = dict(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )

        # The block always starts with a ResNet.
        resnet_layers = [ResnetBlock2D(**resnet_config)]
        transformer_layers = []

        for _ in range(num_layers):
            if dual_cross_attention:
                transformer = DualTransformer2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                transformer = Transformer2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            transformer_layers.append(transformer)
            resnet_layers.append(ResnetBlock2D(**resnet_config))

        self.attentions = nn.ModuleList(transformer_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        posemb: Optional = None,
    ) -> torch.FloatTensor:
        """Apply the leading ResNet, then each (transformer, ResNet) pair in order.

        All conditioning arguments, including ``posemb``, are forwarded unchanged to every
        transformer layer; only the first element of each transformer's return value is used.
        """
        leading_resnet = self.resnets[0]
        trailing_resnets = self.resnets[1:]

        hidden_states = leading_resnet(hidden_states, temb)
        for transformer, following_resnet in zip(self.attentions, trailing_resnets):
            transformer_out = transformer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
                posemb=posemb,
            )
            hidden_states = following_resnet(transformer_out[0], temb)

        return hidden_states
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
    """Middle UNet block pairing ResNets with added-KV ``Attention`` layers.

    Layout: one leading ``ResnetBlock2D``, then ``num_layers`` pairs of (``Attention``,
    ``ResnetBlock2D``). The head count is ``in_channels // attention_head_dim``. Each attention
    layer gets its own processor: ``AttnAddedKVProcessor2_0`` when
    ``torch.nn.functional.scaled_dot_product_attention`` exists, otherwise ``AttnAddedKVProcessor``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        skip_time_act=False,
        only_cross_attention=False,
        cross_attention_norm=None,
    ):
        super().__init__()

        self.has_cross_attention = True

        self.attention_head_dim = attention_head_dim
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        self.num_heads = in_channels // self.attention_head_dim

        # Shared configuration for every ResNet in the block.
        resnet_config = dict(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            skip_time_act=skip_time_act,
        )

        # The block always starts with a ResNet.
        resnet_layers = [ResnetBlock2D(**resnet_config)]
        attention_layers = []

        for _ in range(num_layers):
            if hasattr(F, "scaled_dot_product_attention"):
                processor = AttnAddedKVProcessor2_0()
            else:
                processor = AttnAddedKVProcessor()

            attention_layers.append(
                Attention(
                    query_dim=in_channels,
                    cross_attention_dim=in_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
            resnet_layers.append(ResnetBlock2D(**resnet_config))

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Apply the leading ResNet, then each (attention, ResNet) pair.

        Mask selection: an explicit ``attention_mask`` always wins. Otherwise
        ``encoder_attention_mask`` is used when ``encoder_hidden_states`` is given, and no mask
        is used when it is not.
        """
        extra_attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs

        # An explicit `attention_mask` takes priority without consulting `encoder_attention_mask`;
        # this keeps compatibility with UnCLIP, which passes cross-attn masks via `attention_mask`.
        # TODO: once UnCLIP uses `encoder_attention_mask`, this reduces to
        #       mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
        if attention_mask is not None:
            mask = attention_mask
        elif encoder_hidden_states is None:
            mask = None
        else:
            # Cross-attention is in effect, so the cross-attention mask applies.
            mask = encoder_attention_mask

        hidden_states = self.resnets[0](hidden_states, temb)
        for attention, following_resnet in zip(self.attentions, self.resnets[1:]):
            attended = attention(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **extra_attn_kwargs,
            )
            hidden_states = following_resnet(attended, temb)

        return hidden_states
class AttnDownBlock2D(nn.Module):
    """Down block of ``num_layers`` (ResNet, self-attention) pairs plus an optional downsampler.

    Args:
        in_channels: Input width of the first ResNet; later ResNets take ``out_channels``.
        out_channels: Output width of every ResNet and the width of every attention layer.
        temb_channels: Time-embedding width passed to each ``ResnetBlock2D``.
        dropout: Dropout value forwarded to each ``ResnetBlock2D``.
        num_layers: Number of (ResNet, attention) pairs.
        resnet_eps: Epsilon forwarded to the ResNet and attention layers.
        resnet_time_scale_shift: Forwarded as ``time_embedding_norm``.
        resnet_act_fn: Forwarded as ``non_linearity``.
        resnet_groups: Group count forwarded to the ResNets and as attention ``norm_num_groups``.
        resnet_pre_norm: Forwarded as ``pre_norm``.
        attention_head_dim: Channels per attention head; ``None`` falls back to ``out_channels``
            (a warning is logged).
        output_scale_factor: Forwarded to the ResNets and, as ``rescale_output_factor``, to attention.
        downsample_padding: Padding for the ``Downsample2D`` layer (``downsample_type == "conv"``).
        downsample_type: ``"conv"`` builds a ``Downsample2D``, ``"resnet"`` builds a
            ``ResnetBlock2D(down=True)``; any other value disables downsampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        downsample_padding=1,
        downsample_type="conv",
    ):
        super().__init__()
        resnets = []
        attentions = []
        self.downsample_type = downsample_type

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias of `warning`. The message now names
            # `out_channels`, which is the value actually used as the fallback.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # Only the first ResNet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if downsample_type == "conv":
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        elif downsample_type == "resnet":
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        down=True,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, upsample_size=None):
        """Run every (ResNet, attention) pair, then the downsampler if present.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` holds the result after
            each pair and, when downsampling is enabled, after each downsampler.
            ``upsample_size`` is accepted but not used here.
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)
            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                # The ResNet-style downsampler also consumes the time embedding.
                if self.downsample_type == "resnet":
                    hidden_states = downsampler(hidden_states, temb=temb)
                else:
                    hidden_states = downsampler(hidden_states)

                output_states += (hidden_states,)

        return hidden_states, output_states
class CrossAttnDownBlock2D(nn.Module):
    """Down block of ``num_layers`` (ResNet, cross-attention transformer) pairs.

    ``dual_cross_attention`` selects ``DualTransformer2DModel`` over ``Transformer2DModel``.
    When ``add_downsample`` is true a single ``Downsample2D`` follows the pairs. Gradient
    checkpointing is off by default and is toggled through ``gradient_checkpointing``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        resnet_layers = []
        transformer_layers = []
        for layer_idx in range(num_layers):
            # Only the first ResNet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                transformer = DualTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                transformer = Transformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            transformer_layers.append(transformer)

        self.attentions = nn.ModuleList(transformer_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.downsamplers = None
        else:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        posemb: Optional = None,
    ):
        """Run each (ResNet, transformer) pair, then the downsampler if present.

        Returns:
            ``(hidden_states, output_states)``; ``output_states`` collects the result after
            every pair and after every downsampler.
        """

        def wrap_for_checkpoint(module, return_dict=None):
            # Binds `module` (and an optional `return_dict`) so checkpoint can call it positionally.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        collected = ()

        for resnet_layer, transformer in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet_layer),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Positional order handed to the transformer (transformer_2d) is kept exactly.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(transformer, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    posemb,
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )[0]
            else:
                hidden_states = resnet_layer(hidden_states, temb)
                transformer_out = transformer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                    posemb=posemb,
                )
                hidden_states = transformer_out[0]

            collected = collected + (hidden_states,)

        if self.downsamplers is not None:
            for downsample in self.downsamplers:
                hidden_states = downsample(hidden_states)

            collected = collected + (hidden_states,)

        return hidden_states, collected
class DownBlock2D(nn.Module):
    """Down block made of ``num_layers`` ResNets followed by an optional ``Downsample2D``.

    The first ResNet maps ``in_channels`` to ``out_channels``; the rest keep ``out_channels``.
    Gradient checkpointing is off by default and is toggled through ``gradient_checkpointing``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Settings shared by every ResNet; only the input width differs per layer.
        shared_resnet_args = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )
        resnet_layers = [
            ResnetBlock2D(in_channels=in_channels if layer_idx == 0 else out_channels, **shared_resnet_args)
            for layer_idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.downsamplers = None
        else:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Run every ResNet, then the downsampler if present.

        Returns:
            ``(hidden_states, output_states)``; ``output_states`` collects the result after
            every ResNet and after every downsampler.
        """

        def wrap_for_checkpoint(module):
            # Binds `module` so checkpoint can call it with positional inputs.
            def run(*inputs):
                return module(*inputs)

            return run

        collected = ()

        for resnet_layer in self.resnets:
            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch versions that accept it.
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet_layer), hidden_states, temb, **ckpt_kwargs
                )
            else:
                hidden_states = resnet_layer(hidden_states, temb)

            collected = collected + (hidden_states,)

        if self.downsamplers is not None:
            for downsample in self.downsamplers:
                hidden_states = downsample(hidden_states)

            collected = collected + (hidden_states,)

        return hidden_states, collected
class DownEncoderBlock2D(nn.Module):
    """Encoder down block: ``num_layers`` ResNets without time embedding, then an optional ``Downsample2D``.

    The first ResNet maps ``in_channels`` to ``out_channels``; the rest keep ``out_channels``.
    Every ResNet is built with ``temb_channels=None`` and called with ``temb=None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Settings shared by every ResNet; only the input width differs per layer.
        shared_resnet_args = dict(
            out_channels=out_channels,
            temb_channels=None,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )
        resnet_layers = [
            ResnetBlock2D(in_channels=in_channels if layer_idx == 0 else out_channels, **shared_resnet_args)
            for layer_idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.downsamplers = None
        else:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )

    def forward(self, hidden_states):
        """Apply every ResNet in order, then every downsampler if any; return the final tensor."""
        for resnet_layer in self.resnets:
            hidden_states = resnet_layer(hidden_states, temb=None)

        if self.downsamplers is None:
            return hidden_states

        for downsample in self.downsamplers:
            hidden_states = downsample(hidden_states)

        return hidden_states
class AttnDownEncoderBlock2D(nn.Module):
    """Encoder down block of ``num_layers`` (ResNet, self-attention) pairs plus an optional ``Downsample2D``.

    The ResNets are built with ``temb_channels=None`` and called with ``temb=None``.

    Args:
        in_channels: Input width of the first ResNet; later ResNets take ``out_channels``.
        out_channels: Output width of every ResNet and the width of every attention layer.
        dropout: Dropout value forwarded to each ``ResnetBlock2D``.
        num_layers: Number of (ResNet, attention) pairs.
        resnet_eps: Epsilon forwarded to the ResNet and attention layers.
        resnet_time_scale_shift: Forwarded as ``time_embedding_norm``.
        resnet_act_fn: Forwarded as ``non_linearity``.
        resnet_groups: Group count forwarded to the ResNets and as attention ``norm_num_groups``.
        resnet_pre_norm: Forwarded as ``pre_norm``.
        attention_head_dim: Channels per attention head; ``None`` falls back to ``out_channels``
            (a warning is logged).
        output_scale_factor: Forwarded to the ResNets and, as ``rescale_output_factor``, to attention.
        add_downsample: Whether to append a ``Downsample2D`` layer.
        downsample_padding: Padding forwarded to ``Downsample2D``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        attentions = []

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias of `warning`. The message now names
            # `out_channels`, which is the value actually used as the fallback.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # Only the first ResNet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        """Apply every (ResNet, attention) pair, then every downsampler if any; return the final tensor."""
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = attn(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states
class AttnSkipDownBlock2D(nn.Module):
    """Down block of (ResNet, self-attention) pairs with a separately downsampled skip branch.

    When ``add_downsample`` is true the block also owns a FIR-downsampling ResNet
    (``resnet_down``), a ``FirDownsample2D`` applied to ``skip_sample``, and a 1x1 convolution
    (``skip_conv``, 3 input channels) whose output is added to the hidden states.

    Args:
        in_channels: Input width of the first ResNet; later ResNets take ``out_channels``.
        out_channels: Output width of every ResNet and the width of every attention layer.
        temb_channels: Time-embedding width passed to each ``ResnetBlock2D``.
        dropout: Dropout value forwarded to each ``ResnetBlock2D``.
        num_layers: Number of (ResNet, attention) pairs.
        resnet_eps: Epsilon forwarded to the ResNet and attention layers.
        resnet_time_scale_shift: Forwarded as ``time_embedding_norm``.
        resnet_act_fn: Forwarded as ``non_linearity``.
        resnet_pre_norm: Forwarded as ``pre_norm``.
        attention_head_dim: Channels per attention head; ``None`` falls back to ``out_channels``
            (a warning is logged).
        output_scale_factor: Forwarded to the ResNets and, as ``rescale_output_factor``, to attention.
        add_downsample: Whether to build the downsampling ResNet, skip downsampler and skip conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=np.sqrt(2.0),
        add_downsample=True,
    ):
        super().__init__()
        self.attentions = nn.ModuleList([])
        self.resnets = nn.ModuleList([])

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias of `warning`. The message now names
            # `out_channels`, which is the value actually used as the fallback.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # Only the first ResNet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            self.resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=min(in_channels // 4, 32),
                    groups_out=min(out_channels // 4, 32),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            self.attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=32,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        if add_downsample:
            self.resnet_down = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=min(out_channels // 4, 32),
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_in_shortcut=True,
                down=True,
                kernel="fir",
            )
            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
        else:
            self.resnet_down = None
            self.downsamplers = None
            self.skip_conv = None

    def forward(self, hidden_states, temb=None, skip_sample=None):
        """Run the (ResNet, attention) pairs, then the downsampling stage if present.

        Returns:
            ``(hidden_states, output_states, skip_sample)``. ``skip_sample`` is returned unchanged
            when the block has no downsamplers; otherwise it is the downsampled skip input.
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            hidden_states = self.resnet_down(hidden_states, temb)
            for downsampler in self.downsamplers:
                skip_sample = downsampler(skip_sample)

            # Project the downsampled skip input to `out_channels` and merge it in.
            hidden_states = self.skip_conv(skip_sample) + hidden_states

            output_states += (hidden_states,)

        return hidden_states, output_states, skip_sample
class SkipDownBlock2D(nn.Module):
    """Down block of plain ResNets with a separately downsampled skip branch.

    With ``add_downsample`` the block also builds a FIR-downsampling ResNet (``resnet_down``),
    a ``FirDownsample2D`` for ``skip_sample`` and a 1x1 ``skip_conv`` (3 input channels) whose
    output is added to the hidden states. ``downsample_padding`` is accepted but not used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        output_scale_factor=np.sqrt(2.0),
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Settings every ResNet in this block has in common.
        shared_resnet_args = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
        )
        out_groups = min(out_channels // 4, 32)

        self.resnets = nn.ModuleList([])
        for layer_idx in range(num_layers):
            # Only the first ResNet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            self.resnets.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    groups=min(layer_in_channels // 4, 32),
                    groups_out=out_groups,
                    **shared_resnet_args,
                )
            )

        if not add_downsample:
            self.resnet_down = None
            self.downsamplers = None
            self.skip_conv = None
        else:
            self.resnet_down = ResnetBlock2D(
                in_channels=out_channels,
                groups=out_groups,
                use_in_shortcut=True,
                down=True,
                kernel="fir",
                **shared_resnet_args,
            )
            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, hidden_states, temb=None, skip_sample=None):
        """Run the ResNets, then the downsampling stage if present.

        Returns:
            ``(hidden_states, output_states, skip_sample)``; ``skip_sample`` comes back unchanged
            when the block has no downsamplers.
        """
        collected = ()

        for resnet_layer in self.resnets:
            hidden_states = resnet_layer(hidden_states, temb)
            collected += (hidden_states,)

        if self.downsamplers is None:
            return hidden_states, collected, skip_sample

        hidden_states = self.resnet_down(hidden_states, temb)
        for downsample in self.downsamplers:
            skip_sample = downsample(skip_sample)

        # Project the downsampled skip input and merge it into the main path.
        hidden_states = self.skip_conv(skip_sample) + hidden_states
        collected += (hidden_states,)

        return hidden_states, collected, skip_sample
class ResnetDownsampleBlock2D(nn.Module):
    """Down block of ``num_layers`` ResNets whose optional downsampler is itself a ``ResnetBlock2D(down=True)``.

    The first ResNet maps ``in_channels`` to ``out_channels``; the rest keep ``out_channels``.
    Gradient checkpointing is off by default and is toggled through ``gradient_checkpointing``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        skip_time_act=False,
    ):
        super().__init__()

        # Settings shared by the regular ResNets and the downsampling ResNet.
        shared_resnet_args = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            skip_time_act=skip_time_act,
        )
        resnet_layers = [
            ResnetBlock2D(in_channels=in_channels if layer_idx == 0 else out_channels, **shared_resnet_args)
            for layer_idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.downsamplers = None
        else:
            self.downsamplers = nn.ModuleList(
                [ResnetBlock2D(in_channels=out_channels, down=True, **shared_resnet_args)]
            )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Run every ResNet, then the downsampling ResNet(s) if present.

        Returns:
            ``(hidden_states, output_states)``; ``output_states`` collects the result after
            every ResNet and after every downsampler.
        """

        def wrap_for_checkpoint(module):
            # Binds `module` so checkpoint can call it with positional inputs.
            def run(*inputs):
                return module(*inputs)

            return run

        collected = ()

        for resnet_layer in self.resnets:
            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch versions that accept it.
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet_layer), hidden_states, temb, **ckpt_kwargs
                )
            else:
                hidden_states = resnet_layer(hidden_states, temb)

            collected = collected + (hidden_states,)

        if self.downsamplers is not None:
            for downsample in self.downsamplers:
                # The downsampler is a ResNet, so it also receives the time embedding.
                hidden_states = downsample(hidden_states, temb)

            collected = collected + (hidden_states,)

        return hidden_states, collected
class SimpleCrossAttnDownBlock2D(nn.Module):
    """Down block of ``num_layers`` (ResNet, added-KV ``Attention``) pairs with a ResNet downsampler.

    The head count is ``out_channels // attention_head_dim``. Each attention layer gets its own
    processor: ``AttnAddedKVProcessor2_0`` when ``torch.nn.functional.scaled_dot_product_attention``
    exists, otherwise ``AttnAddedKVProcessor``.

    Args:
        in_channels: Input width of the first ResNet; later ResNets take ``out_channels``.
        out_channels: Output width of every ResNet and the query width of every attention layer.
        temb_channels: Time-embedding width passed to each ``ResnetBlock2D``.
        dropout: Dropout value forwarded to each ``ResnetBlock2D``.
        num_layers: Number of (ResNet, attention) pairs.
        resnet_eps: Epsilon forwarded to the ResNets.
        resnet_time_scale_shift: Forwarded as ``time_embedding_norm``.
        resnet_act_fn: Forwarded as ``non_linearity``.
        resnet_groups: Group count forwarded to the ResNets and as attention ``norm_num_groups``.
        resnet_pre_norm: Forwarded as ``pre_norm``.
        attention_head_dim: Channels per attention head.
        cross_attention_dim: Forwarded to attention as ``added_kv_proj_dim``.
        output_scale_factor: Forwarded to the ResNets.
        add_downsample: Whether to append a ``ResnetBlock2D(down=True)`` downsampler.
        skip_time_act: Forwarded to the ResNets.
        only_cross_attention: Forwarded to attention.
        cross_attention_norm: Forwarded to attention.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_downsample=True,
        skip_time_act=False,
        only_cross_attention=False,
        cross_attention_norm=None,
    ):
        super().__init__()

        self.has_cross_attention = True

        resnets = []
        attentions = []

        self.attention_head_dim = attention_head_dim
        self.num_heads = out_channels // self.attention_head_dim

        for i in range(num_layers):
            # Only the first ResNet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    skip_time_act=skip_time_act,
                )
            )

            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            attentions.append(
                Attention(
                    query_dim=out_channels,
                    cross_attention_dim=out_channels,
                    heads=self.num_heads,
                    dim_head=attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        skip_time_act=skip_time_act,
                        down=True,
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run each (ResNet, attention) pair, then the downsampling ResNet(s) if present.

        Mask selection: an explicit ``attention_mask`` always wins. Otherwise
        ``encoder_attention_mask`` is used when ``encoder_hidden_states`` is given, and no mask
        is used when it is not.

        Returns:
            ``(hidden_states, output_states)``; ``output_states`` collects the result after
            every pair and after every downsampler.
        """
        output_states = ()
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        if attention_mask is None:
            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)

            # Attention is called the same way in both modes. The previous checkpointed call
            # passed `return_dict=False` plus a positional kwargs dict and indexed the result
            # with `[0]`, none of which matches how this layer is invoked in the eager branch.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **cross_attention_kwargs,
            )

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, temb)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
class KDownBlock2D(nn.Module):
    """Down block of ``num_layers`` ``ada_group`` ResNets with an optional ``KDownsample2D``.

    Group counts are derived from ``resnet_group_size``: ``in_channels // resnet_group_size`` for
    the input norm and ``out_channels // resnet_group_size`` for the output norm of each ResNet.
    Gradient checkpointing is off by default and is toggled through ``gradient_checkpointing``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 4,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
        resnet_group_size: int = 32,
        add_downsample=False,
    ):
        super().__init__()

        resnet_layers = []
        for layer_idx in range(num_layers):
            # Only the first ResNet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=layer_in_channels // resnet_group_size,
                    groups_out=out_channels // resnet_group_size,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    time_embedding_norm="ada_group",
                    conv_shortcut_bias=False,
                )
            )
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.downsamplers = None
        else:
            # Note from YiYi: FirDownsample2D might work here too; to be looked into later.
            self.downsamplers = nn.ModuleList([KDownsample2D()])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Run every ResNet, then the downsampler(s) if present.

        Returns:
            ``(hidden_states, output_states)``; ``output_states`` holds the result after each
            ResNet only. The downsampled tensor is not appended to it.
        """

        def wrap_for_checkpoint(module):
            # Binds `module` so checkpoint can call it with positional inputs.
            def run(*inputs):
                return module(*inputs)

            return run

        collected = ()

        for resnet_layer in self.resnets:
            if self.training and self.gradient_checkpointing:
                # `use_reentrant` is only passed on torch versions that accept it.
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet_layer), hidden_states, temb, **ckpt_kwargs
                )
            else:
                hidden_states = resnet_layer(hidden_states, temb)

            collected += (hidden_states,)

        if self.downsamplers is not None:
            for downsample in self.downsamplers:
                hidden_states = downsample(hidden_states)

        return hidden_states, collected
might be able to use FirDownsample2D, look into details later self.downsamplers = nn.ModuleList([KDownsample2D()]) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states, output_states class KCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, cross_attention_dim: int, dropout: float = 0.0, num_layers: int = 4, resnet_group_size: int = 32, add_downsample=True, attention_head_dim: int = 64, add_self_attention: bool = False, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) attentions.append( KAttentionBlock( out_channels, out_channels // attention_head_dim, attention_head_dim, cross_attention_dim=cross_attention_dim, temb_channels=temb_channels, 
class AttnUpBlock2D(nn.Module):
    r"""
    Up block that pairs every `ResnetBlock2D` with an `Attention` layer and optionally upsamples at the end.

    Parameters:
        in_channels (`int`): Channel count of the skip connection consumed by the last resnet.
        prev_output_channel (`int`): Channel count of the incoming hidden states (input of the first resnet).
        out_channels (`int`): Channel count produced by every resnet and attention layer.
        temb_channels (`int`): Forwarded to every resnet as `temb_channels`.
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to every resnet.
        num_layers (`int`, *optional*, defaults to 1): Number of resnet/attention pairs.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Forwarded as `eps` to resnets and attention layers.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Forwarded as `time_embedding_norm`.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Forwarded as `non_linearity`.
        resnet_groups (`int`, *optional*, defaults to 32): Group count for resnets and attention norms.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`): Forwarded as `pre_norm`.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Channels per attention head. `None` falls back to `out_channels` (a single head) with a warning.
        output_scale_factor (`float`, *optional*, defaults to 1.0): Forwarded to resnets and attention layers.
        upsample_type (`str`, *optional*, defaults to `"conv"`):
            `"conv"` appends an `Upsample2D`, `"resnet"` appends an upsampling `ResnetBlock2D`; any other value
            disables upsampling.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        upsample_type="conv",
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.upsample_type = upsample_type

        if attention_head_dim is None:
            # The fallback actually applied is `out_channels`, so the message names that quantity
            # (it previously said `in_channels`). `Logger.warn` is a deprecated alias of `warning`.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # The last resnet receives the skip connection coming from the block's own input resolution.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if upsample_type == "conv":
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        elif upsample_type == "resnet":
            self.upsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        up=True,
                    )
                ]
            )
        else:
            self.upsamplers = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """
        Concatenate one skip tensor per layer, run resnet + attention, then upsample.

        Args:
            hidden_states: Incoming feature map.
            res_hidden_states_tuple: Skip tensors; one is consumed per layer, starting from the end of the tuple.
            temb: Embedding passed to every resnet (and to a `"resnet"`-type upsampler).
            upsample_size: Accepted for interface parity with other up blocks; not used by this block.

        Returns:
            The processed (and possibly upsampled) hidden states.
        """
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                # Only the resnet-style upsampler consumes the embedding.
                if self.upsample_type == "resnet":
                    hidden_states = upsampler(hidden_states, temb=temb)
                else:
                    hidden_states = upsampler(hidden_states)

        return hidden_states
dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, posemb: Optional = None, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( 
class UpBlock2D(nn.Module):
    r"""
    Plain up block: a stack of `ResnetBlock2D` layers, each fed a skip connection, plus an optional `Upsample2D`.

    Parameters:
        in_channels (`int`): Channel count of the skip connection consumed by the last resnet.
        prev_output_channel (`int`): Channel count of the incoming hidden states (input of the first resnet).
        out_channels (`int`): Channel count produced by every resnet.
        temb_channels (`int`): Forwarded to every resnet as `temb_channels`.
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to every resnet.
        num_layers (`int`, *optional*, defaults to 1): Number of resnets.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Forwarded as `eps`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Forwarded as `time_embedding_norm`.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Forwarded as `non_linearity`.
        resnet_groups (`int`, *optional*, defaults to 32): Forwarded as `groups`.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`): Forwarded as `pre_norm`.
        output_scale_factor (`float`, *optional*, defaults to 1.0): Forwarded to every resnet.
        add_upsample (`bool`, *optional*, defaults to `True`): Whether to append a convolutional `Upsample2D`.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        final_idx = num_layers - 1
        layers = []
        for idx in range(num_layers):
            # First layer takes the previous block's output; last layer takes the `in_channels`-wide skip.
            main_width = out_channels if idx != 0 else prev_output_channel
            skip_width = out_channels if idx != final_idx else in_channels
            layers.append(
                ResnetBlock2D(
                    in_channels=main_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """
        Concatenate one skip tensor per resnet (taken from the end of `res_hidden_states_tuple`), run the resnets,
        then apply every upsampler with `upsample_size`. Returns the resulting hidden states.
        """
        pending_skips = res_hidden_states_tuple

        for block in self.resnets:
            # pop res hidden states
            skip, pending_skips = pending_skips[-1], pending_skips[:-1]
            hidden_states = torch.cat((hidden_states, skip), dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = block(hidden_states, temb)
                continue

            def make_runner(module):
                def runner(*inputs):
                    return module(*inputs)

                return runner

            # `use_reentrant` is only passed on torch versions that report >= 1.11.0.
            extra = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(make_runner(block), hidden_states, temb, **extra)

        samplers = () if self.upsamplers is None else self.upsamplers
        for sampler in samplers:
            hidden_states = sampler(hidden_states, upsample_size)

        return hidden_states
self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class AttnUpDecoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, add_upsample=True, temb_channels=None, ): super().__init__() resnets = [] attentions = [] if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." ) attention_head_dim = out_channels for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None def forward(self, hidden_states, temb=None): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb) hidden_states = 
class AttnSkipUpBlock2D(nn.Module):
    r"""
    Up block with a resnet stack, a single attention layer and a separate 3-channel skip branch.

    Parameters:
        in_channels (`int`):
            Channel count of the skip connection consumed by the last resnet; also the input width of the
            `FirUpsample2D` applied to `skip_sample`.
        prev_output_channel (`int`): Channel count of the incoming hidden states (input of the first resnet).
        out_channels (`int`): Channel count produced by every resnet and by the attention layer.
        temb_channels (`int`): Forwarded to every resnet as `temb_channels`.
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to every resnet.
        num_layers (`int`, *optional*, defaults to 1): Number of resnets.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Forwarded as `eps` to resnets, attention and group norm.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Forwarded as `time_embedding_norm`.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Forwarded as `non_linearity`.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`): Forwarded as `pre_norm`.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Channels per attention head. `None` falls back to `out_channels` with a warning.
        output_scale_factor (`float`, *optional*, defaults to `sqrt(2)`): Forwarded to resnets and attention.
        add_upsample (`bool`, *optional*, defaults to `True`):
            Whether to build the FIR upsampling resnet and the skip projection (`skip_norm` -> `act` -> `skip_conv`).
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=np.sqrt(2.0),
        add_upsample=True,
    ):
        super().__init__()
        self.attentions = nn.ModuleList([])
        self.resnets = nn.ModuleList([])

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            resnet_channels = resnet_in_channels + res_skip_channels

            self.resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    # A quarter of the *total* input width, capped at 32. The previous expression
                    # `a + b // 4` only divided the skip width because `//` binds tighter than `+`.
                    # Both forms evaluate to 32 whenever the total width is at least 128.
                    groups=min(resnet_channels // 4, 32),
                    groups_out=min(out_channels // 4, 32),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        self.attentions.append(
            Attention(
                out_channels,
                heads=out_channels // attention_head_dim,
                dim_head=attention_head_dim,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,
                norm_num_groups=32,
                residual_connection=True,
                bias=True,
                upcast_softmax=True,
                _from_deprecated_attn_block=True,
            )
        )

        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
        if add_upsample:
            self.resnet_up = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=min(out_channels // 4, 32),
                groups_out=min(out_channels // 4, 32),
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_in_shortcut=True,
                up=True,
                kernel="fir",
            )
            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.skip_norm = torch.nn.GroupNorm(
                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
            )
            self.act = nn.SiLU()
        else:
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
        """
        Run the resnet stack and the attention layer, then update the skip branch.

        Args:
            hidden_states: Incoming feature map.
            res_hidden_states_tuple: Skip tensors; one is consumed per resnet, starting from the end of the tuple.
            temb: Embedding passed to every resnet.
            skip_sample: Running skip-branch tensor from the previous block, or `None` for the first block.

        Returns:
            A `(hidden_states, skip_sample)` tuple. When `skip_sample` was `None` and the block has no upsampling
            path, the returned `skip_sample` is the integer `0`.
        """
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)

        hidden_states = self.attentions[0](hidden_states)

        if skip_sample is not None:
            skip_sample = self.upsampler(skip_sample)
        else:
            # Neutral element for the addition below.
            skip_sample = 0

        if self.resnet_up is not None:
            # Project the features to 3 channels and accumulate them on the skip branch.
            skip_sample_states = self.skip_norm(hidden_states)
            skip_sample_states = self.act(skip_sample_states)
            skip_sample_states = self.skip_conv(skip_sample_states)

            skip_sample = skip_sample + skip_sample_states

            hidden_states = self.resnet_up(hidden_states, temb)

        return hidden_states, skip_sample
SkipUpBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, output_scale_factor=np.sqrt(2.0), add_upsample=True, upsample_padding=1, ): super().__init__() self.resnets = nn.ModuleList([]) for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min((resnet_in_channels + res_skip_channels) // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, up=True, kernel="fir", ) self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.skip_norm = torch.nn.GroupNorm( num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True ) self.act = nn.SiLU() else: self.resnet_up = None self.skip_conv = None self.skip_norm = None self.act = None def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): for resnet in self.resnets: # pop res hidden states 
class ResnetUpsampleBlock2D(nn.Module):
    r"""
    Up block whose upsampler is itself a `ResnetBlock2D` constructed with `up=True`.

    Parameters:
        in_channels (`int`): Channel count of the skip connection consumed by the last resnet.
        prev_output_channel (`int`): Channel count of the incoming hidden states (input of the first resnet).
        out_channels (`int`): Channel count produced by every resnet.
        temb_channels (`int`): Forwarded to every resnet as `temb_channels`.
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to every resnet.
        num_layers (`int`, *optional*, defaults to 1): Number of resnets before the upsampler.
        resnet_eps (`float`, *optional*, defaults to 1e-6): Forwarded as `eps`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Forwarded as `time_embedding_norm`.
        resnet_act_fn (`str`, *optional*, defaults to `"swish"`): Forwarded as `non_linearity`.
        resnet_groups (`int`, *optional*, defaults to 32): Forwarded as `groups`.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`): Forwarded as `pre_norm`.
        output_scale_factor (`float`, *optional*, defaults to 1.0): Forwarded to every resnet.
        add_upsample (`bool`, *optional*, defaults to `True`): Whether to append the upsampling resnet.
        skip_time_act (`bool`, *optional*, defaults to `False`): Forwarded to every resnet.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        skip_time_act=False,
    ):
        super().__init__()

        # Keyword arguments common to the stacked resnets and to the upsampling resnet.
        shared = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            skip_time_act=skip_time_act,
        )

        final_idx = num_layers - 1
        layers = []
        for idx in range(num_layers):
            # First layer takes the previous block's output; last layer takes the `in_channels`-wide skip.
            main_width = out_channels if idx != 0 else prev_output_channel
            skip_width = out_channels if idx != final_idx else in_channels
            layers.append(ResnetBlock2D(in_channels=main_width + skip_width, **shared))
        self.resnets = nn.ModuleList(layers)

        if not add_upsample:
            self.upsamplers = None
        else:
            self.upsamplers = nn.ModuleList([ResnetBlock2D(in_channels=out_channels, up=True, **shared)])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """
        Concatenate one skip tensor per resnet (taken from the end of `res_hidden_states_tuple`), run the resnets,
        then apply every upsampler with `temb`. `upsample_size` is accepted but not used. Returns the hidden states.
        """
        pending_skips = res_hidden_states_tuple

        for block in self.resnets:
            # pop res hidden states
            skip, pending_skips = pending_skips[-1], pending_skips[:-1]
            hidden_states = torch.cat((hidden_states, skip), dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = block(hidden_states, temb)
                continue

            def make_runner(module):
                def runner(*inputs):
                    return module(*inputs)

                return runner

            # `use_reentrant` is only passed on torch versions that report >= 1.11.0.
            extra = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(make_runner(block), hidden_states, temb, **extra)

        samplers = () if self.upsamplers is None else self.upsamplers
        for sampler in samplers:
            hidden_states = sampler(hidden_states, temb)

        return hidden_states
out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) processor = ( AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() ) attentions.append( Attention( query_dim=out_channels, cross_attention_dim=out_channels, heads=self.num_heads, dim_head=self.attention_head_dim, added_kv_proj_dim=cross_attention_dim, norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, up=True, ) ] ) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. # then we can simplify this whole if/else block to: # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, mask, cross_attention_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, temb) return hidden_states class KUpBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 5, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: Optional[int] = 32, add_upsample=True, ): super().__init__() resnets = [] k_in_channels = 2 * out_channels k_out_channels = 
class KCrossAttnUpBlock2D(nn.Module):
    r"""
    Up block that interleaves `ResnetBlock2D` layers with `KAttentionBlock` layers.

    Parameters:
        in_channels (`int`): Channel count the block must output (used as the final layer's output width).
        out_channels (`int`): Working channel count of the intermediate layers.
        temb_channels (`int`): Forwarded to resnets and attention blocks as `temb_channels`.
        dropout (`float`, *optional*, defaults to 0.0): Forwarded to every resnet.
        num_layers (`int`, *optional*, defaults to 4): The block builds `num_layers - 1` resnet/attention pairs.
        resnet_eps (`float`, *optional*, defaults to 1e-5): Forwarded as `eps`.
        resnet_act_fn (`str`, *optional*, defaults to `"gelu"`): Forwarded as `non_linearity`.
        resnet_group_size (`int`, *optional*, defaults to 32):
            Channels per normalisation group; group counts are computed as `channels // resnet_group_size`.
        attention_head_dim (`int`, *optional*, defaults to 1): Channels per attention head.
        cross_attention_dim (`int`, *optional*, defaults to 768): Forwarded to every attention block.
        add_upsample (`bool`, *optional*, defaults to `True`): Whether to append a `KUpsample2D`.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to every attention block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 4,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
        resnet_group_size: int = 32,
        attention_head_dim=1,  # attention dim_head
        cross_attention_dim: int = 768,
        add_upsample: bool = True,
        upcast_attention: bool = False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        # The block's role is inferred from how the three channel counts relate to each other.
        is_first_block = in_channels == out_channels == temb_channels
        is_middle_block = in_channels != out_channels
        add_self_attention = is_first_block

        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        # in_channels, and out_channels for the block (k-unet)
        k_in_channels = out_channels if is_first_block else 2 * out_channels
        k_out_channels = in_channels

        num_layers = num_layers - 1

        for i in range(num_layers):
            in_channels = k_in_channels if i == 0 else out_channels
            groups = in_channels // resnet_group_size
            groups_out = out_channels // resnet_group_size

            # Only the last resnet of a middle block changes width to `k_out_channels`.
            if is_middle_block and (i == num_layers - 1):
                conv_2d_out_channels = k_out_channels
            else:
                conv_2d_out_channels = None

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    conv_2d_out_channels=conv_2d_out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=groups,
                    groups_out=groups_out,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                    time_embedding_norm="ada_group",
                    conv_shortcut_bias=False,
                )
            )
            attentions.append(
                KAttentionBlock(
                    k_out_channels if (i == num_layers - 1) else out_channels,
                    k_out_channels // attention_head_dim
                    if (i == num_layers - 1)
                    else out_channels // attention_head_dim,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    temb_channels=temb_channels,
                    attention_bias=True,
                    add_self_attention=add_self_attention,
                    cross_attention_norm="layer_norm",
                    upcast_attention=upcast_attention,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([KUpsample2D()])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Concatenate the last skip tensor (if not `None`), run every resnet/attention pair, then upsample.

        Only the final entry of `res_hidden_states_tuple` is used, and it is concatenated once before the first
        layer. `upsample_size` is accepted but not used by this block.

        Returns:
            The processed hidden states.
        """
        res_hidden_states_tuple = res_hidden_states_tuple[-1]
        if res_hidden_states_tuple is not None:
            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Mirrors the eager branch below: the attention block is called without `return_dict` and its
                # result is used as-is. The earlier `return_dict=False` / `[0]` combination diverged from the
                # eager branch by indexing into the result.
                # NOTE(review): the positional order assumes `KAttentionBlock.forward` takes
                # (hidden_states, encoder_hidden_states, emb, attention_mask, cross_attention_kwargs,
                # encoder_attention_mask) - confirm against its signature.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    cross_attention_kwargs,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    emb=temb,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, upcast_attention: bool = False, temb_channels: int = 768, # for ada_group_norm add_self_attention: bool = False, cross_attention_norm: Optional[str] = None, group_size: int = 32, ): super().__init__() self.add_self_attention = add_self_attention # 1. Self-Attn if add_self_attention: self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=None, cross_attention_norm=None, ) # 2. Cross-Attn self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, cross_attention_norm=cross_attention_norm, ) def _to_3d(self, hidden_states, height, weight): return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) def _to_4d(self, hidden_states, height, weight): return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, # TODO: mark emb as non-optional (self.norm2 requires it). # requires assessing impact of change to positional param interface. 
emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} # 1. Self-Attention if self.add_self_attention: norm_hidden_states = self.norm1(hidden_states, emb) height, weight = norm_hidden_states.shape[2:] norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) hidden_states = attn_output + hidden_states # 2. Cross-Attention/None norm_hidden_states = self.norm2(hidden_states, emb) height, weight = norm_hidden_states.shape[2:] norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) hidden_states = attn_output + hidden_states return hidden_states ================================================ FILE: 4DoF/diffusers/models/unet_2d_blocks_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross Attention 2D Downsizing block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        attentions = []

        for i in range(self.num_layers):
            # Only the first resnet sees the block's input width.
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """Return `(hidden_states, output_states)`; `output_states` collects every intermediate output."""
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D downsizing block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []

        for i in range(self.num_layers):
            # Only the first resnet sees the block's input width.
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        """Return `(hidden_states, output_states)`; `output_states` collects every intermediate output."""
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        attentions = []

        for i in range(self.num_layers):
            # Each resnet consumes the running hidden states concatenated with one skip connection.
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

        self.resnets = resnets
        self.attentions = attentions

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        """Consume one skip connection per layer (last entry first) and return the resulting hidden states."""
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of resnet layers
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []

        for i in range(self.num_layers):
            # Each resnet consumes the running hidden states concatenated with one skip connection.
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        """Consume one skip connection per layer (last entry first) and return the resulting hidden states."""
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """Apply the leading resnet, then alternate attention and resnet layers; returns the hidden states."""
        # Forward `deterministic` to the first resnet as well, so it follows the caller's flag like every
        # other sub-module here instead of always running with the resnet's default.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Defaults to `None`, so the dataclass can be instantiated without a sample.
    sample: torch.FloatTensor = None
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. 
num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. 
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = 
None, addition_embed_type_num_heads=64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 
) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. 
Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if 
isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, 
cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, 
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect every attention processor used by the model.

    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by weight name (the dotted path of the owning module followed by `.processor`).
    """
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        # Any module exposing `set_processor` is treated as an attention layer.
        if hasattr(module, "set_processor"):
            processors[f"{name}.processor"] = module.processor

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors


def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the entry for one of the attention layers.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    if isinstance(processor, dict):
        # Entries are popped while walking the module tree. Work on a shallow copy so the
        # dictionary owned by the caller is not emptied as a side effect of this call.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
""" self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 
encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. 
class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = 
added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) ================================================ FILE: 4DoF/diffusers/models/unet_2d_condition_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional, Tuple, Union import flax import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..utils import BaseOutput from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .modeling_flax_utils import FlaxModelMixin from .unet_2d_blocks_flax import ( FlaxCrossAttnDownBlock2D, FlaxCrossAttnUpBlock2D, FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, FlaxUpBlock2D, ) @flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ The output of [`FlaxUNet2DConditionModel`]. Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: jnp.ndarray @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample shaped output. This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its general usage and behavior. 
Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: sample_size (`int`, *optional*): The size of the input sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): The tuple of downsample blocks to use. up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int` or `Tuple[int]`, *optional*): The number of attention heads. cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 
def init_weights(self, rng: "jax.Array") -> FrozenDict:
    """
    Initialise the model parameters by running `self.init` on dummy inputs.

    Args:
        rng: PRNG key; it is split into one key for the `"params"` stream and one for the `"dropout"` stream.

    Returns:
        The `"params"` entry of the variables returned by `self.init`.
    """
    # NOTE: the annotation is a string on purpose. `jax.random.KeyArray` was removed from newer
    # JAX releases, and an eagerly evaluated annotation would fail at definition time.

    # Dummy inputs: batch of one, shapes taken from the module configuration.
    sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
    sample = jnp.zeros(sample_shape, dtype=jnp.float32)
    timesteps = jnp.ones((1,), dtype=jnp.int32)
    encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = self.num_attention_heads or self.attention_head_dim # input self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # time self.time_proj = FlaxTimesteps( block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift ) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) only_cross_attention = self.only_cross_attention if isinstance(only_cross_attention, bool): only_cross_attention = (only_cross_attention,) * len(self.down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(self.down_block_types) # down down_blocks = [] output_channel = block_out_channels[0] for i, down_block_type in enumerate(self.down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 if down_block_type == "CrossAttnDownBlock2D": down_block = FlaxCrossAttnDownBlock2D( in_channels=input_channel, out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, num_attention_heads=num_attention_heads[i], add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], use_memory_efficient_attention=self.use_memory_efficient_attention, 
dtype=self.dtype, ) else: down_block = FlaxDownBlock2D( in_channels=input_channel, out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, add_downsample=not is_final_block, dtype=self.dtype, ) down_blocks.append(down_block) self.down_blocks = down_blocks # mid self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], dropout=self.dropout, num_attention_heads=num_attention_heads[-1], use_linear_projection=self.use_linear_projection, use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) # up up_blocks = [] reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] is_final_block = i == len(block_out_channels) - 1 if up_block_type == "CrossAttnUpBlock2D": up_block = FlaxCrossAttnUpBlock2D( in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, num_layers=self.layers_per_block + 1, num_attention_heads=reversed_num_attention_heads[i], add_upsample=not is_final_block, dropout=self.dropout, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], use_memory_efficient_attention=self.use_memory_efficient_attention, dtype=self.dtype, ) else: up_block = FlaxUpBlock2D( in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, num_layers=self.layers_per_block + 1, add_upsample=not is_final_block, dropout=self.dropout, dtype=self.dtype, ) up_blocks.append(up_block) prev_output_channel = output_channel self.up_blocks = up_blocks # out self.conv_norm_out = 
nn.GroupNorm(num_groups=32, epsilon=1e-5) self.conv_out = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__( self, sample, timesteps, encoder_hidden_states, down_block_additional_residuals=None, mid_block_additional_residual=None, return_dict: bool = True, train: bool = False, ) -> Union[FlaxUNet2DConditionOutput, Tuple]: r""" Args: sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # 1. time if not isinstance(timesteps, jnp.ndarray): timesteps = jnp.array([timesteps], dtype=jnp.int32) elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: timesteps = timesteps.astype(dtype=jnp.float32) timesteps = jnp.expand_dims(timesteps, 0) t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) # 2. pre-process sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) # 3. 
down down_block_res_samples = (sample,) for down_block in self.down_blocks: if isinstance(down_block, FlaxCrossAttnDownBlock2D): sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) else: sample, res_samples = down_block(sample, t_emb, deterministic=not train) down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample += down_block_additional_residual new_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) if mid_block_additional_residual is not None: sample += mid_block_additional_residual # 5. up for up_block in self.up_blocks: res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] if isinstance(up_block, FlaxCrossAttnUpBlock2D): sample = up_block( sample, temb=t_emb, encoder_hidden_states=encoder_hidden_states, res_hidden_states_tuple=res_samples, deterministic=not train, ) else: sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) # 6. post-process sample = self.conv_norm_out(sample) sample = nn.silu(sample) sample = self.conv_out(sample) sample = jnp.transpose(sample, (0, 3, 1, 2)) if not return_dict: return (sample,) return FlaxUNet2DConditionOutput(sample=sample) ================================================ FILE: 4DoF/diffusers/models/unet_3d_blocks.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """
    Instantiate the 3D down block named by `down_block_type`.

    Supported names are `"DownBlock3D"` and `"CrossAttnDownBlock3D"`. The attention-related arguments
    (`cross_attention_dim`, `num_attention_heads`, `dual_cross_attention`, `use_linear_projection`,
    `only_cross_attention`, `upcast_attention`) are only forwarded to `CrossAttnDownBlock3D`.

    NOTE: `resnet_groups` and `downsample_padding` default to `None` here and are forwarded as-is, so an
    omitted value replaces the block constructors' own defaults rather than falling back to them.

    Raises:
        ValueError: if `down_block_type` is unknown, or if `cross_attention_dim` is `None` for
            `"CrossAttnDownBlock3D"`.
    """
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """
    Instantiate the 3D up block named by `up_block_type`.

    Supported names are `"UpBlock3D"` and `"CrossAttnUpBlock3D"`. The attention-related arguments are only
    forwarded to `CrossAttnUpBlock3D`.

    NOTE: `resnet_groups` defaults to `None` here and is forwarded as-is, so an omitted value replaces the
    block constructors' own default rather than falling back to it.

    Raises:
        ValueError: if `up_block_type` is unknown, or if `cross_attention_dim` is `None` for
            `"CrossAttnUpBlock3D"`.
    """
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{up_block_type} does not exist.")


class UNetMidBlock3DCrossAttn(nn.Module):
    """
    Middle block of the 3D UNet.

    Holds one leading resnet + temporal convolution, followed by `num_layers` groups of
    (spatial transformer, temporal transformer, resnet, temporal convolution). The channel count stays at
    `in_channels` throughout.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=True,
        upcast_attention=False,
    ):
        # NOTE: `dual_cross_attention` is accepted but not referenced in this constructor.
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        # Fall back to a group count derived from the channel count when none is given.
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        # The temporal conv dropout is fixed at 0.1 and does not follow the `dropout` argument.
        temp_convs = [
            TemporalConvLayer(
                in_channels,
                in_channels,
                dropout=0.1,
            )
        ]
        attentions = []
        temp_attentions = []

        for _ in range(num_layers):
            attentions.append(
                Transformer2DModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    in_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temp_convs.append(
                TemporalConvLayer(
                    in_channels,
                    in_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """
        Run the leading resnet + temporal conv, then each (attention, temporal attention, resnet, temporal
        conv) group in order, and return the resulting hidden states.

        `attention_mask` is accepted but not used in this method.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
        # The first resnet/temp_conv were applied above, hence the `[1:]` slices.
        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        return hidden_states


class CrossAttnDownBlock3D(nn.Module):
    """
    Down block with cross-attention.

    Each of the `num_layers` layers is (resnet, temporal conv, spatial transformer, temporal transformer);
    an optional `Downsample2D` follows when `add_downsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        # NOTE: `dual_cross_attention` is accepted but not referenced in this constructor.
        super().__init__()
        resnets = []
        attentions = []
        temp_attentions = []
        temp_convs = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            # Only the first resnet changes the channel count; later ones map out_channels -> out_channels.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # The temporal conv dropout is fixed at 0.1 and does not follow the `dropout` argument.
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """
        Apply every layer, then the optional downsamplers.

        Returns:
            `(hidden_states, output_states)` where `output_states` is a tuple holding the hidden states after
            each layer plus, when downsamplers exist, the downsampled hidden states.
        """
        # TODO(Patrick, William) - attention mask is not used
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class DownBlock3D(nn.Module):
    """
    Down block without attention.

    Each of the `num_layers` layers is (resnet, temporal conv); an optional `Downsample2D` follows when
    `add_downsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            # Only the first resnet changes the channel count; later ones map out_channels -> out_channels.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # The temporal conv dropout is fixed at 0.1 and does not follow the `dropout` argument.
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, num_frames=1):
        """
        Apply every (resnet, temporal conv) layer, then the optional downsamplers.

        Returns:
            `(hidden_states, output_states)` where `output_states` is a tuple holding the hidden states after
            each layer plus, when downsamplers exist, the downsampled hidden states.
        """
        output_states = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class CrossAttnUpBlock3D(nn.Module):
    """
    Up block with cross-attention.

    Each of the `num_layers` layers concatenates one skip connection and then applies (resnet, temporal
    conv, spatial transformer, temporal transformer); an optional `Upsample2D` follows when `add_upsample`
    is true.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        # NOTE: `dual_cross_attention` is accepted but not referenced in this constructor.
        super().__init__()
        resnets = []
        temp_convs = []
        attentions = []
        temp_attentions = []

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        for i in range(num_layers):
            # Each resnet sees the running hidden state concatenated with one skip connection:
            # the last layer's skip has `in_channels` channels, the others `out_channels`.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # The temporal conv dropout is fixed at 0.1 and does not follow the `dropout` argument.
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            )
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,
                    num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)
        self.attentions = nn.ModuleList(attentions)
        self.temp_attentions = nn.ModuleList(temp_attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        num_frames=1,
        cross_attention_kwargs=None,
    ):
        """
        Apply every layer, consuming one entry from the end of `res_hidden_states_tuple` per layer, then the
        optional upsamplers. Returns the resulting hidden states.
        """
        # TODO(Patrick, William) - attention mask is not used
        for resnet, temp_conv, attn, temp_attn in zip(
            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # Skip connection is concatenated along dim 1.
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            hidden_states = temp_attn(
                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
            )[0]

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock3D(nn.Module):
    """
    Up block without attention.

    Each of the `num_layers` layers concatenates one skip connection and then applies (resnet, temporal
    conv); an optional `Upsample2D` follows when `add_upsample` is true.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []
        temp_convs = []

        for i in range(num_layers):
            # Each resnet sees the running hidden state concatenated with one skip connection:
            # the last layer's skip has `in_channels` channels, the others `out_channels`.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            # The temporal conv dropout is fixed at 0.1 and does not follow the `dropout` argument.
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,
                    out_channels,
                    dropout=0.1,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.temp_convs = nn.ModuleList(temp_convs)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
        """
        Apply every layer, consuming one entry from the end of `res_hidden_states_tuple` per layer, then the
        optional upsamplers. Returns the resulting hidden states.
        """
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # Skip connection is concatenated along dim 1.
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


================================================
FILE: 4DoF/diffusers/models/unet_3d_condition.py
================================================
# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# Copyright 2023 The ModelScope Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    The output of [`UNet3DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 64):
            Used as the number of attention heads of the blocks when `num_attention_heads` is `None` (see the
            naming note in `__init__`).
        num_attention_heads (`int`, *optional*):
            The number of attention heads. Passing a value currently raises `NotImplementedError`.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1024,
        attention_head_dim: Union[int, Tuple[int]] = 64,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise NotImplementedError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_kernel = 3
        conv_out_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        time_embed_dim = block_out_channels[0] * 4
        self.time_proj = Timesteps(block_out_channels[0], True, 0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # NOTE: the head count of this temporal transformer is fixed at 8, independent of the config.
        self.transformer_in = TransformerTemporalModel(
            num_attention_heads=8,
            attention_head_dim=attention_head_dim,
            in_channels=block_out_channels[0],
            num_layers=1,
        )

        # class embedding
        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast a scalar head count to one entry per down block.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=False,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=False,
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=False,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = nn.SiLU()
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
# Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def enable_forward_chunking(self, chunk_size=None, dim=0): """ Sets the attention processor to use [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). Parameters: chunk_size (`int`, *optional*): The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually over each tensor of dim=`dim`. dim (`int`, *optional*, defaults to `0`): The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) or dim=1 (sequence length). 
""" if dim not in [0, 1]: raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") # By default chunk size is 1 chunk_size = chunk_size or 1 def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) def disable_forward_chunking(self): def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) for child in module.children(): fn_recursive_feed_forward(child, chunk_size, dim) for module in self.children(): fn_recursive_feed_forward(module, None, 0) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" The [`UNet3DConditionModel`] forward method. 
Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML num_frames = sample.shape[2] timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) emb = emb.repeat_interleave(repeats=num_frames, dim=0) encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) # 2. pre-process sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = self.conv_in(sample) sample = self.transformer_in( sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples += (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, upsample_size=upsample_size, attention_mask=attention_mask, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, num_frames=num_frames, ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) # reshape to (batch, channel, framerate, width, height) sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) ================================================ FILE: 4DoF/diffusers/models/vae.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from ..utils import BaseOutput, is_torch_version, randn_tensor
from .attention_processor import SpatialNorm
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class Encoder(nn.Module):
    """
    Convolutional encoder: `conv_in` -> down blocks -> mid block -> GroupNorm/SiLU -> `conv_out`.

    Args:
        in_channels: Number of channels of the input to `conv_in`.
        out_channels: Number of output channels; `conv_out` emits `2 * out_channels` channels when `double_z`
            is true, otherwise `out_channels`.
        down_block_types: One block type name per down block, forwarded to `get_down_block`.
        block_out_channels: Output channels of each down block (indexed in step with `down_block_types`).
        layers_per_block: `num_layers` passed to every down block.
        norm_num_groups: Number of groups for the group norms.
        act_fn: Activation name forwarded to the resnet blocks.
        double_z: Whether `conv_out` produces twice `out_channels` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # the last block keeps the spatial resolution (no downsampler)
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        """
        Run `x` through `conv_in`, the down blocks, the mid block and the output head.

        Activation checkpointing is used for the down/mid blocks only when the module is in training mode and
        `self.gradient_checkpointing` is set.
        """
        sample = x
        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch >= 1.11.0; older versions get no extra kwargs
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # down
            for down_block in self.down_blocks:
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, **ckpt_kwargs)
            # middle
            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, **ckpt_kwargs)
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class Decoder(nn.Module):
    """
    Convolutional decoder: `conv_in` -> mid block -> up blocks -> norm/SiLU -> `conv_out`.

    Args:
        in_channels: Number of channels of the input to `conv_in`.
        out_channels: Number of channels produced by `conv_out`.
        up_block_types: One block type name per up block, forwarded to `get_up_block`.
        block_out_channels: Channel widths; consumed in reversed order by the up blocks.
        layers_per_block: Every up block is built with `layers_per_block + 1` layers.
        norm_num_groups: Number of groups for the group norms.
        act_fn: Activation name forwarded to the resnet blocks.
        norm_type: `"group"` or `"spatial"`. With `"spatial"`, `temb_channels` is set to `in_channels` and the
            output norm is a `SpatialNorm`; otherwise `temb_channels` is `None` and the output norm is a
            `GroupNorm`.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # each block takes the previous block's output width as its input width
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # the last block keeps the spatial resolution (no upsampler)
            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z, latent_embeds=None):
        """
        Decode `z`.

        `latent_embeds` is forwarded to the mid block and every up block; it is also passed to the output norm
        when it is not `None`. After the mid block the activations are cast to the dtype of the first up-block
        parameter.
        """
        sample = z
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch >= 1.11.0; older versions get no extra kwargs
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, latent_embeds, **ckpt_kwargs
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), sample, latent_embeds, **ckpt_kwargs
                )
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
    ):
        """
        Args:
            n_e: Number of codebook entries (rows of the embedding table).
            vq_embed_dim: Dimension of each codebook entry.
            beta: Weight of one of the two loss terms (which one depends on `legacy`).
            remap: Optional path loaded with `np.load`; its content is registered as the `used` buffer.
            unknown_index: `"random"`, `"extra"` or an integer; only read when `remap` is given.
            sane_index_shape: If true, `forward` returns indices reshaped to `(batch, height, width)`.
            legacy: If true, `beta` scales the `(z_q - z.detach())` term, otherwise the `(z_q.detach() - z)` term.
        """
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # reserve one additional index past the used ones for unknown entries
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """
        Map codebook indices to their positions in `self.used`.

        Indices that do not occur in `self.used` are replaced by a random index in `[0, re_embed)` when
        `unknown_index == "random"`, otherwise by `unknown_index`. The result has the shape of `inds`.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """
        Inverse of `remap_to_used`: look up positions in `self.used` to recover codebook indices.

        When an extra index was reserved (`re_embed > len(used)`), positions at or beyond `len(used)` are
        treated as position 0. The input tensor is left unmodified.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # simply set to zero; out-of-place because `reshape` may return a view of the caller's tensor
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        """
        Quantize `z` to its nearest codebook entries.

        Returns:
            A tuple `(z_q, loss, (perplexity, min_encodings, min_encoding_indices))`; `perplexity` and
            `min_encodings` are always `None`.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """
        Look up the codebook vectors for `indices`.

        When `shape` is not `None` the result is viewed as `shape` and permuted with `(0, 3, 1, 2)`.
        NOTE(review): with `remap` set, `shape` is indexed before the `None` check, so `shape=None` is
        presumably unsupported in that configuration - confirm against callers.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class DiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian parameterised by a tensor holding mean and log-variance stacked along dim 1.

    Args:
        parameters: Tensor split into two equal chunks along `dim=1`: mean, then log-variance.
        deterministic: If true, `std` and `var` are zero tensors, so `sample()` returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # bound the log-variance before exponentiating it below
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self):
        """Return a one-element zero tensor on the device and in the dtype of `parameters`."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw `mean + std * eps` with `eps` produced by `randn_tensor` using the optional `generator`."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None):
        """
        KL divergence summed over dims `[1, 2, 3]`.

        With `other=None` the divergence is taken against a standard normal; otherwise against `other`
        (read through its `mean`, `var` and `logvar`). A deterministic distribution returns a single zero.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of `sample`, summed over `dims`.

        Computed as `0.5 * sum(log(2 * pi) + logvar + (sample - mean) ** 2 / var)`. A deterministic
        distribution returns a single zero.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        """Return the mean, which is the mode of a Gaussian."""
        return self.mean
================================================ FILE: 4DoF/diffusers/models/vae_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers import math from functools import partial from typing import Tuple import flax import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..utils import BaseOutput from .modeling_flax_utils import FlaxModelMixin @flax.struct.dataclass class FlaxDecoderOutput(BaseOutput): """ Output of decoding method. Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): The decoded output sample from the last layer of the model. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): The `dtype` of the parameters. """ sample: jnp.ndarray @flax.struct.dataclass class FlaxAutoencoderKLOutput(BaseOutput): """ Output of AutoencoderKL encoding method. Args: latent_dist (`FlaxDiagonalGaussianDistribution`): Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. 
""" latent_dist: "FlaxDiagonalGaussianDistribution" class FlaxUpsample2D(nn.Module): """ Flax implementation of 2D Upsample layer Args: in_channels (`int`): Input channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, hidden_states): batch, height, width, channels = hidden_states.shape hidden_states = jax.image.resize( hidden_states, shape=(batch, height * 2, width * 2, channels), method="nearest", ) hidden_states = self.conv(hidden_states) return hidden_states class FlaxDownsample2D(nn.Module): """ Flax implementation of 2D Downsample layer Args: in_channels (`int`): Input channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), strides=(2, 2), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states): pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim hidden_states = jnp.pad(hidden_states, pad_width=pad) hidden_states = self.conv(hidden_states) return hidden_states class FlaxResnetBlock2D(nn.Module): """ Flax implementation of 2D Resnet Block. Args: in_channels (`int`): Input channels out_channels (`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for group norm. use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): Whether to use `nin_shortcut`. 
This activates a new layer inside ResNet block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int = None dropout: float = 0.0 groups: int = 32 use_nin_shortcut: bool = None dtype: jnp.dtype = jnp.float32 def setup(self): out_channels = self.in_channels if self.out_channels is None else self.out_channels self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) self.conv1 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) self.dropout_layer = nn.Dropout(self.dropout) self.conv2 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut self.conv_shortcut = None if use_nin_shortcut: self.conv_shortcut = nn.Conv( out_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states, deterministic=True): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: residual = self.conv_shortcut(residual) return hidden_states + residual class FlaxAttentionBlock(nn.Module): r""" Flax Convolutional based multi-head attention block for diffusion-based VAE. 
Parameters: channels (:obj:`int`): Input channels num_head_channels (:obj:`int`, *optional*, defaults to `None`): Number of attention heads num_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for group norm dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ channels: int num_head_channels: int = None num_groups: int = 32 dtype: jnp.dtype = jnp.float32 def setup(self): self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 dense = partial(nn.Dense, self.channels, dtype=self.dtype) self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) self.query, self.key, self.value = dense(), dense(), dense() self.proj_attn = dense() def transpose_for_scores(self, projection): new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) new_projection = projection.reshape(new_projection_shape) # (B, T, H, D) -> (B, H, T, D) new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) return new_projection def __call__(self, hidden_states): residual = hidden_states batch, height, width, channels = hidden_states.shape hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.reshape((batch, height * width, channels)) query = self.query(hidden_states) key = self.key(hidden_states) value = self.value(hidden_states) # transpose query = self.transpose_for_scores(query) key = self.transpose_for_scores(key) value = self.transpose_for_scores(value) # compute attentions scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) attn_weights = nn.softmax(attn_weights, axis=-1) # attend to values hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) 
hidden_states = hidden_states.reshape(new_hidden_states_shape) hidden_states = self.proj_attn(hidden_states) hidden_states = hidden_states.reshape((batch, height, width, channels)) hidden_states = hidden_states + residual return hidden_states class FlaxDownEncoderBlock2D(nn.Module): r""" Flax Resnet blocks-based Encoder block for diffusion-based VAE. Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 add_downsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets if self.add_downsample: self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_downsample: hidden_states = self.downsamplers_0(hidden_states) return hidden_states class FlaxUpDecoderBlock2D(nn.Module): r""" Flax Resnet blocks-based Decoder block for diffusion-based VAE. 
Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 add_upsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets if self.add_upsample: self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_upsample: hidden_states = self.upsamplers_0(hidden_states) return hidden_states class FlaxUNetMidBlock2D(nn.Module): r""" Flax Unet Mid-Block module. 
Parameters: in_channels (:obj:`int`): Input channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet and Attention block group norm num_attention_heads (:obj:`int`, *optional*, defaults to `1`): Number of attention heads for each attention block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 num_attention_heads: int = 1 dtype: jnp.dtype = jnp.float32 def setup(self): resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) # there is always at least one resnet resnets = [ FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, groups=resnet_groups, dtype=self.dtype, ) ] attentions = [] for _ in range(self.num_layers): attn_block = FlaxAttentionBlock( channels=self.in_channels, num_head_channels=self.num_attention_heads, num_groups=resnet_groups, dtype=self.dtype, ) attentions.append(attn_block) res_block = FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, groups=resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets self.attentions = attentions def __call__(self, hidden_states, deterministic=True): hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, deterministic=deterministic) return hidden_states class FlaxEncoder(nn.Module): r""" Flax Implementation of VAE Encoder. This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. 
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (:obj:`int`, *optional*, defaults to 3): Input channels out_channels (:obj:`int`, *optional*, defaults to 3): Output channels down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): DownEncoder block type block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) block_out_channels: Tuple[int] = (64,) layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" double_z: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels # in self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # downsampling down_blocks = [] output_channel = block_out_channels[0] for i, _ in enumerate(self.down_block_types): input_channel = 
output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = FlaxDownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers=self.layers_per_block, resnet_groups=self.norm_num_groups, add_downsample=not is_final_block, dtype=self.dtype, ) down_blocks.append(down_block) self.down_blocks = down_blocks # middle self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], resnet_groups=self.norm_num_groups, num_attention_heads=None, dtype=self.dtype, ) # end conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( conv_out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, sample, deterministic: bool = True): # in sample = self.conv_in(sample) # downsampling for block in self.down_blocks: sample = block(sample, deterministic=deterministic) # middle sample = self.mid_block(sample, deterministic=deterministic) # end sample = self.conv_norm_out(sample) sample = nn.swish(sample) sample = self.conv_out(sample) return sample class FlaxDecoder(nn.Module): r""" Flax Implementation of VAE Decoder. This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. 
Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (:obj:`int`, *optional*, defaults to 3): Input channels out_channels (:obj:`int`, *optional*, defaults to 3): Output channels up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): UpDecoder block type block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 up_block_types: Tuple[str] = ("UpDecoderBlock2D",) block_out_channels: int = (64,) layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels # z to block_in self.conv_in = nn.Conv( block_out_channels[-1], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # middle self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], resnet_groups=self.norm_num_groups, num_attention_heads=None, dtype=self.dtype, ) # upsampling reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] 
up_blocks = [] for i, _ in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 up_block = FlaxUpDecoderBlock2D( in_channels=prev_output_channel, out_channels=output_channel, num_layers=self.layers_per_block + 1, resnet_groups=self.norm_num_groups, add_upsample=not is_final_block, dtype=self.dtype, ) up_blocks.append(up_block) prev_output_channel = output_channel self.up_blocks = up_blocks # end self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, sample, deterministic: bool = True): # z to block_in sample = self.conv_in(sample) # middle sample = self.mid_block(sample, deterministic=deterministic) # upsampling for block in self.up_blocks: sample = block(sample, deterministic=deterministic) sample = self.conv_norm_out(sample) sample = nn.swish(sample) sample = self.conv_out(sample) return sample class FlaxDiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): # Last axis to account for channels-last self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) self.logvar = jnp.clip(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = jnp.exp(0.5 * self.logvar) self.var = jnp.exp(self.logvar) if self.deterministic: self.var = self.std = jnp.zeros_like(self.mean) def sample(self, key): return self.mean + self.std * jax.random.normal(key, self.mean.shape) def kl(self, other=None): if self.deterministic: return jnp.array([0.0]) if other is None: return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) return 0.5 * jnp.sum( jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, axis=[1, 2, 3], ) def nll(self, sample, axis=[1, 2, 3]): if self.deterministic: return 
jnp.array([0.0]) logtwopi = jnp.log(2.0 * jnp.pi) return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) def mode(self): return self.mean @flax_register_to_config class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): r""" Flax implementation of a VAE model with KL loss for decoding latent representations. This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its general usage and behavior. Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): Tuple of upsample block types. block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple of block output channels. layers_per_block (`int`, *optional*, defaults to `2`): Number of ResNet layer for each block. act_fn (`str`, *optional*, defaults to `silu`): The activation function to use. latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. 
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. sample_size (`int`, *optional*, defaults to 32): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): The `dtype` of the parameters. """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) up_block_types: Tuple[str] = ("UpDecoderBlock2D",) block_out_channels: Tuple[int] = (64,) layers_per_block: int = 1 act_fn: str = "silu" latent_channels: int = 4 norm_num_groups: int = 32 sample_size: int = 32 scaling_factor: float = 0.18215 dtype: jnp.dtype = jnp.float32 def setup(self): self.encoder = FlaxEncoder( in_channels=self.config.in_channels, out_channels=self.config.latent_channels, down_block_types=self.config.down_block_types, block_out_channels=self.config.block_out_channels, layers_per_block=self.config.layers_per_block, act_fn=self.config.act_fn, norm_num_groups=self.config.norm_num_groups, double_z=True, dtype=self.dtype, ) self.decoder = FlaxDecoder( in_channels=self.config.latent_channels, out_channels=self.config.out_channels, up_block_types=self.config.up_block_types, block_out_channels=self.config.block_out_channels, layers_per_block=self.config.layers_per_block, norm_num_groups=self.config.norm_num_groups, act_fn=self.config.act_fn, dtype=self.dtype, 
) self.quant_conv = nn.Conv( 2 * self.config.latent_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) self.post_quant_conv = nn.Conv( self.config.latent_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} return self.init(rngs, sample)["params"] def encode(self, sample, deterministic: bool = True, return_dict: bool = True): sample = jnp.transpose(sample, (0, 2, 3, 1)) hidden_states = self.encoder(sample, deterministic=deterministic) moments = self.quant_conv(hidden_states) posterior = FlaxDiagonalGaussianDistribution(moments) if not return_dict: return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) def decode(self, latents, deterministic: bool = True, return_dict: bool = True): if latents.shape[-1] != self.config.latent_channels: latents = jnp.transpose(latents, (0, 2, 3, 1)) hidden_states = self.post_quant_conv(latents) hidden_states = self.decoder(hidden_states, deterministic=deterministic) hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) if not return_dict: return (hidden_states,) return FlaxDecoderOutput(sample=hidden_states) def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) if sample_posterior: rng = self.make_rng("gaussian") hidden_states = posterior.latent_dist.sample(rng) else: hidden_states = posterior.latent_dist.mode() sample = self.decode(hidden_states, return_dict=return_dict).sample if not return_dict: return (sample,) return FlaxDecoderOutput(sample=sample) 
================================================ FILE: 4DoF/diffusers/models/vq_model.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, apply_forward_hook from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer @dataclass class VQEncoderOutput(BaseOutput): """ Output of VQModel encoding method. Args: latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The encoded output sample from the last layer of the model. """ latents: torch.FloatTensor class VQModel(ModelMixin, ConfigMixin): r""" A VQ-VAE model for decoding latent representations. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types. 
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. scaling_factor (`float`, *optional*, defaults to `0.18215`): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
""" @register_to_config def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 3, sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, norm_type: str = "group", # group, spatial ): super().__init__() # pass init params to Encoder self.encoder = Encoder( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=False, ) vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, norm_type=norm_type, ) @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: h = self.encoder(x) h = self.quant_conv(h) if not return_dict: return (h,) return VQEncoderOutput(latents=h) @apply_forward_hook def decode( self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: quant, emb_loss, info = self.quantize(h) else: quant = h quant2 = self.post_quant_conv(quant) dec = self.decoder(quant2, quant if 
self.config.norm_type == "spatial" else None) if not return_dict: return (dec,) return DecoderOutput(sample=dec) def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r""" The [`VQModel`] forward method. Args: sample (`torch.FloatTensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. Returns: [`~models.vq_model.VQEncoderOutput`] or `tuple`: If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` is returned. """ x = sample h = self.encode(x).latents dec = self.decode(h).sample if not return_dict: return (dec,) return DecoderOutput(sample=dec) ================================================ FILE: 4DoF/diffusers/optimization.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""PyTorch optimization for diffusion models.""" import math from enum import Enum from typing import Optional, Union from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR from .utils import logging logger = logging.get_logger(__name__) class SchedulerType(Enum): LINEAR = "linear" COSINE = "cosine" COSINE_WITH_RESTARTS = "cosine_with_restarts" POLYNOMIAL = "polynomial" CONSTANT = "constant" CONSTANT_WITH_WARMUP = "constant_with_warmup" PIECEWISE_CONSTANT = "piecewise_constant" def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): """ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) return 1.0 return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. step_rules (`string`): The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps and multiple 0.005 for the other steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ rules_dict = {} rule_list = step_rules.split(",") for rule_str in rule_list[:-1]: value_str, steps_str = rule_str.split(":") steps = int(steps_str) value = float(value_str) rules_dict[steps] = value last_lr_multiple = float(rule_list[-1]) def create_rules_function(rules_dict, last_lr_multiple): def rule_func(steps: int) -> float: sorted_steps = sorted(rules_dict.keys()) for i, sorted_step in enumerate(sorted_steps): if steps < sorted_step: return rules_dict[sorted_steps[i]] return last_lr_multiple return rule_func rules_func = create_rules_function(rules_dict, last_lr_multiple) return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 
Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) ) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_cosine_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_periods (`float`, *optional*, defaults to 0.5): The number of periods of the cosine function in a schedule (the default is to just decrease from the max value to 0 following a half-cosine). last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_cosine_with_hard_restarts_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_cycles (`int`, *optional*, defaults to 1): The number of hard restarts to use. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    Raises:
        ValueError: If the optimizer's default lr is not strictly greater than `lr_end`.
    """

    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")

    # Both values are constant for the lifetime of the schedule, so compute them once.
    lr_range = lr_init - lr_end
    # `max(1, ...)` mirrors the other schedules in this module: without it a run whose warmup spans all training
    # steps (num_training_steps == num_warmup_steps) divided by zero at that step.
    decay_steps = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining**power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)
This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_training_steps (`int``, *optional*): The number of training steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_cycles (`int`, *optional*): The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. power (`float`, *optional*, defaults to 1.0): Power factor. See `POLYNOMIAL` scheduler last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. """ name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return schedule_func(optimizer, last_epoch=last_epoch) if name == SchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") if name == SchedulerType.COSINE_WITH_RESTARTS: return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, last_epoch=last_epoch, ) if name == SchedulerType.POLYNOMIAL: return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, last_epoch=last_epoch, ) return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch ) 
================================================ FILE: 4DoF/diffusers/pipeline_utils.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE: This file is deprecated and will be removed in a future version. # It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 from .utils import deprecate deprecate( "pipelines_utils", "0.22.0", "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. 
Please import from diffusers.pipelines.pipeline_utils instead.", standard_warn=False, stacklevel=3, ) ================================================ FILE: 4DoF/diffusers/pipelines/__init__.py ================================================ from ..utils import ( OptionalDependencyNotAvailable, is_flax_available, is_invisible_watermark_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, is_onnx_available, is_torch_available, is_transformers_available, ) try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline from .stochastic_karras_ve import KarrasVePipeline try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_librosa_objects import * # noqa F403 else: from .audio_diffusion import AudioDiffusionPipeline, Mel try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .audioldm import AudioLDMPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, ) from .deepfloyd_if import ( 
IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, ) from .kandinsky import ( KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline, ) from .kandinsky2_2 import ( KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22ControlnetPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, ) from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, StableDiffusionModelEditingPipeline, StableDiffusionPanoramaPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPipeline, StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import 
VQDiffusionPipeline try: if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_onnx_objects import * # noqa F403 else: from .onnx_utils import OnnxRuntimeModel try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 else: from .stable_diffusion import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: from .stable_diffusion import StableDiffusionKDiffusionPipeline try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_objects import * # noqa F403 else: from .pipeline_flax_utils import FlaxDiffusionPipeline try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .controlnet import FlaxStableDiffusionControlNetPipeline from .stable_diffusion import ( 
@dataclass
# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Alt Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`, *optional*):
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
    """

    # Denoised images, either a list of PIL images or a single numpy array.
    images: Union[List[PIL.Image.Image], np.ndarray]
    # One flag per image; `None` when no safety check was performed.
    nsfw_content_detected: Optional[List[bool]]
class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    """
    XLM-RoBERTa encoder followed by a linear projection of its hidden states.

    When the config carries a truthy `has_pre_transformation` attribute, the projection is applied to the
    layer-normalised second-to-last hidden state instead of the last hidden state.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        """Build the encoder and projection layers described by `config`, then run `post_init()`."""
        super().__init__(config)
        hidden_size = config.hidden_size
        project_dim = config.project_dim

        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(hidden_size, project_dim)

        # Older configs may not define this flag; treat a missing attribute as "disabled".
        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
        if self.has_pre_transformation:
            self.transformation_pre = nn.Linear(hidden_size, project_dim)
            self.pre_LN = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r"""
        Run the base encoder and return a [`TransformationModelOutput`] carrying the projected states.

        All arguments are forwarded unchanged to the base model, except that `output_hidden_states` is forced to
        `True` when the pre-transformation path is active (it needs the second-to-last hidden state) and
        `return_dict` falls back to `config.use_return_dict` when left as `None`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The pre-transformation path reads `hidden_states[-2]`, so they must always be requested for it.
        want_hidden_states = True if self.has_pre_transformation else output_hidden_states

        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=want_hidden_states,
            return_dict=return_dict,
        )

        if self.has_pre_transformation:
            penultimate = self.pre_LN(outputs["hidden_states"][-2])
            projected = self.transformation_pre(penultimate)
        else:
            projected = self.transformation(outputs.last_hidden_state)

        return TransformationModelOutput(
            projection_state=projected,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from packaging import version from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from . 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is scaled so that its standard deviation (taken over every dimension except the first)
    matches that of `noise_pred_text`, then blended with the unscaled prediction using `guidance_rescale` as the
    mixing weight.
    """

    def _std_over_non_leading_dims(tensor):
        # One statistic per leading index, kept broadcastable against the input.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    # Ratio that brings the guided prediction's spread back to the text-conditioned one (fixes overexposure).
    std_ratio = _std_over_non_leading_dims(noise_pred_text) / _std_over_non_leading_dims(noise_cfg)
    rescaled = noise_cfg * std_ratio

    # Blend with the original guided result to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`RobertaSeriesModelWithTransformation`]): Frozen text-encoder. Alt Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`XLMRobertaTokenizer`): Tokenizer of class [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: RobertaSeriesModelWithTransformation,
    tokenizer: XLMRobertaTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components after patching outdated scheduler/UNet configs.

    Outdated settings (`steps_offset != 1`, `clip_sample=True`, and `sample_size < 64` on UNets saved before
    version 0.9.0) trigger a deprecation notice and are overwritten in the component's internal config.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # The config is frozen, so build a patched copy and swap it in.
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is shown verbatim to the user.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Compare on the base version so that dev/pre-release suffixes do not affect the check.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
    text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
    `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
    Note that offloading happens on a submodule basis. Memory savings are higher than with
    `enable_model_cpu_offload`, but performance is lower.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index of the CUDA device the offloaded modules are executed on.

    Raises:
        ImportError: If `accelerate` is missing or older than v0.14.0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for component in (self.unet, self.text_encoder, self.vae):
        cpu_offload(component, execution_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=execution_device, offload_buffers=True)
@property
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    unet = self.unet
    # Without a hook on the UNet itself, the pipeline's own device is used.
    if not hasattr(unet, "_hf_hook"):
        return self.device

    for submodule in unet.modules():
        if not hasattr(submodule, "_hf_hook"):
            continue
        hook = submodule._hf_hook
        # Skip hooks that carry no usable execution device (attribute missing or set to None).
        if getattr(hook, "execution_device", None) is None:
            continue
        return torch.device(hook.execution_device)

    return self.device
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept def decode_latents(self, latents): warnings.warn( ( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead" ), FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, ): r""" Function invoked 
when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic.
latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: Returns: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3.
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from packaging import version from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import AltDiffusionImg2ImgPipeline >>> device = "cuda" >>> model_id_or_path = "BAAI/AltDiffusion-m9" >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) >>> pipe = pipe.to(device) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) >>> # "A fantasy landscape, trending on artstation" >>> prompt = "幻想风景, artstation" >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images >>> images[0].save("幻想风景.png") ``` """ # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): warnings.warn( "The preprocess method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor.preprocess instead", FutureWarning, ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-guided image to image generation using Alt Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`RobertaSeriesModelWithTransformation`]): Frozen text-encoder. Alt Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`XLMRobertaTokenizer`): Tokenizer of class [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: RobertaSeriesModelWithTransformation, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
@property
def _execution_device(self):
    r"""
    Return the device on which the pipeline's models actually execute.

    Once `pipeline.enable_sequential_cpu_offload()` has been called, the execution device can only be
    recovered from the Accelerate hooks attached to the modules, so the UNet's submodules are scanned for the
    first hook that carries a non-`None` `execution_device`. Without any hook on the UNet, or when no hook
    names a device, `self.device` is returned.
    """
    unet = self.unet
    if hasattr(unet, "_hf_hook"):
        for submodule in unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            hooked_device = getattr(hook, "execution_device", None)
            if hooked_device is not None:
                return torch.device(hooked_device)
    return self.device
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that this scheduler's `step` method understands.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the
    signature declares a parameter of that name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502), should lie in [0, 1], and is only used with the DDIMScheduler.

    Args:
        generator: Random generator (or list of generators) to forward to `scheduler.step`.
        eta: DDIM η value to forward to `scheduler.step`.

    Returns:
        `dict`: A mapping holding `"eta"` and/or `"generator"` (in that order) for the supported names.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """
    Build the noised starting latents for the img2img denoising loop.

    Args:
        image: Input image batch. It is moved with `.to(device, dtype)` and indexed through `.shape`, so it
            has to behave like a tensor by the time it gets here. If its second dimension is 4 it is used
            as latents directly and is not passed through the VAE.
        timestep: Timestep(s) handed to `scheduler.add_noise` for the initial noising.
        batch_size: Number of prompts.
        num_images_per_prompt: Number of images generated per prompt; the effective batch size is
            `batch_size * num_images_per_prompt`.
        dtype: Dtype used for the image and for the sampled noise.
        device: Device used for the image and for the sampled noise.
        generator: Optional generator, or a list with one generator per element of the effective batch.

    Returns:
        The initial latents after noise for `timestep` has been added by the scheduler.

    Raises:
        ValueError: If `image` has an unsupported type, if the generator list length differs from the
            effective batch size, or if the image batch cannot be duplicated evenly to that size.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    # NOTE(review): PIL images and lists pass the check above but have no `.to`; presumably callers
    # always preprocess to a tensor first - confirm against `__call__`.
    image = image.to(device=device, dtype=dtype)

    # From here on `batch_size` is the effective batch size.
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # Already latents: skip VAE encoding (and the scaling that goes with it).
        init_latents = image
    else:
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective"
                    f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            # One latent is sampled per generator, so the image batch must first be tiled up to the
            # effective batch size; otherwise `image[i : i + 1]` is an empty slice for the extra indices.
            num_images = image.shape[0]
            if num_images < batch_size:
                if num_images == 0 or batch_size % num_images != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {num_images} to {batch_size} text prompts."
                    )
                image = torch.cat([image] * (batch_size // num_images), dim=0)

            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    shape = init_latents.shape
    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # Noise the clean latents up to `timestep`; denoising starts from this point.
    init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
    latents = init_latents

    return latents
The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 
================================================ FILE: 4DoF/diffusers/pipelines/audio_diffusion/__init__.py ================================================ from .mel import Mel from .pipeline_audio_diffusion import AudioDiffusionPipeline ================================================ FILE: 4DoF/diffusers/pipelines/audio_diffusion/mel.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np # noqa: E402 from ...configuration_utils import ConfigMixin, register_to_config from ...schedulers.scheduling_utils import SchedulerMixin try: import librosa # noqa: E402 _librosa_can_be_imported = True _import_error = "" except Exception as e: _librosa_can_be_imported = False _import_error = ( f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." 
class Mel(ConfigMixin, SchedulerMixin):
    """
    Convert between raw audio and grayscale mel-spectrogram images using librosa.

    Parameters:
        x_res (`int`): x resolution of spectrogram (time)
        y_res (`int`): y resolution of spectrogram (frequency bins)
        sample_rate (`int`): sample rate of audio
        n_fft (`int`): FFT window size, forwarded to librosa as `n_fft`
        hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
        top_db (`int`): loudest in decibels
        n_iter (`int`): number of iterations for Griffin-Lim mel inversion
    """

    config_name = "mel_config.json"

    @register_to_config
    def __init__(
        self,
        x_res: int = 256,
        y_res: int = 256,
        sample_rate: int = 22050,
        n_fft: int = 2048,
        hop_length: int = 512,
        top_db: int = 80,
        n_iter: int = 32,
    ):
        """Store the spectrogram settings.

        Raises:
            ValueError: If librosa could not be imported when this module was loaded.
        """
        self.hop_length = hop_length
        self.sr = sample_rate
        self.n_fft = n_fft
        self.top_db = top_db
        self.n_iter = n_iter
        self.set_resolution(x_res, y_res)
        self.audio = None

        if not _librosa_can_be_imported:
            raise ValueError(_import_error)

    def set_resolution(self, x_res: int, y_res: int):
        """Set resolution.

        Args:
            x_res (`int`): x resolution of spectrogram (time)
            y_res (`int`): y resolution of spectrogram (frequency bins)
        """
        self.x_res = x_res
        self.y_res = y_res
        # One mel band per pixel row.
        self.n_mels = self.y_res
        # Number of audio samples that make up one spectrogram image.
        self.slice_size = self.x_res * self.hop_length - 1

    def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
        """Load audio.

        When both arguments are given, `audio_file` takes precedence.

        Args:
            audio_file (`str`): must be a file on disk due to Librosa limitation or
            raw_audio (`np.ndarray`): audio as numpy array

        Raises:
            ValueError: If neither `audio_file` nor `raw_audio` is provided.
        """
        if audio_file is not None:
            self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
        elif raw_audio is not None:
            self.audio = raw_audio
        else:
            # Fail with a clear message instead of the `len(None)` TypeError further down.
            raise ValueError("Either `audio_file` or `raw_audio` must be provided.")

        # Pad with silence if necessary.
        min_length = self.x_res * self.hop_length
        if len(self.audio) < min_length:
            self.audio = np.concatenate([self.audio, np.zeros((min_length - len(self.audio),))])

    def get_number_of_slices(self) -> int:
        """Get number of slices in audio.

        Returns:
            `int`: number of spectrograms audio can be sliced into
        """
        return len(self.audio) // self.slice_size

    def get_audio_slice(self, slice: int = 0) -> np.ndarray:
        """Get slice of audio.

        Args:
            slice (`int`): slice number of audio (out of get_number_of_slices())

        Returns:
            `np.ndarray`: audio as numpy array
        """
        return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]

    def get_sample_rate(self) -> int:
        """Get sample rate.

        Returns:
            `int`: sample rate of audio
        """
        return self.sr

    def audio_slice_to_image(self, slice: int) -> Image.Image:
        """Convert slice of audio to spectrogram.

        Args:
            slice (`int`): slice number of audio to convert (out of get_number_of_slices())

        Returns:
            `PIL Image`: grayscale image of x_res x y_res
        """
        S = librosa.feature.melspectrogram(
            y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
        )
        log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
        # Shift by top_db and rescale to 0..255; adding 0.5 before the cast rounds to the nearest integer.
        bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
        image = Image.fromarray(bytedata)
        return image

    def image_to_audio(self, image: Image.Image) -> np.ndarray:
        """Converts spectrogram to audio.

        Args:
            image (`PIL Image`): x_res x y_res grayscale image

        Returns:
            audio (`np.ndarray`): raw audio
        """
        bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
        # Inverse of the byte scaling applied in `audio_slice_to_image`.
        log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
        )
        return audio
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from math import acos, sin from typing import List, Tuple, Union import numpy as np import torch from PIL import Image from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, DDPMScheduler from ...utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput from .mel import Mel class AudioDiffusionPipeline(DiffusionPipeline): """ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Parameters: vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None unet ([`UNet2DConditionModel`]): UNET model mel ([`Mel`]): transform audio <-> spectrogram scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler """ _optional_components = ["vqvae"] def __init__( self, vqvae: AutoencoderKL, unet: UNet2DConditionModel, mel: Mel, scheduler: Union[DDIMScheduler, DDPMScheduler], ): super().__init__() self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) def get_default_steps(self) -> int: """Returns default number of steps recommended for inference Returns: `int`: number of steps """ return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000 @torch.no_grad() def __call__( self, batch_size: int = 1, audio_file: str = None, raw_audio: np.ndarray = None, slice: int = 0, start_step: int = 0, steps: int = None, generator: torch.Generator = None, mask_start_secs: float = 0, mask_end_secs: float = 0, step_generator: torch.Generator = None, eta: float = 0, noise: torch.Tensor = None, encoding: torch.Tensor = None, return_dict=True, ) -> Union[ Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]], ]: """Generate random mel spectrogram from audio input and convert to audio. 
Args: batch_size (`int`): number of samples to generate audio_file (`str`): must be a file on disk due to Librosa limitation or raw_audio (`np.ndarray`): audio as numpy array slice (`int`): slice number of audio to convert start_step (int): step to start from steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM) generator (`torch.Generator`): random number generator or None mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end step_generator (`torch.Generator`): random number generator used to de-noise or None eta (`float`): parameter between 0 and 1 used with DDIM scheduler noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim) return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple Returns: `List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios """ steps = steps or self.get_default_steps() self.scheduler.set_timesteps(steps) step_generator = step_generator or generator # For backwards compatibility if type(self.unet.config.sample_size) == int: self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) if noise is None: noise = randn_tensor( ( batch_size, self.unet.config.in_channels, self.unet.config.sample_size[0], self.unet.config.sample_size[1], ), generator=generator, device=self.device, ) images = noise mask = None if audio_file is not None or raw_audio is not None: self.mel.load_audio(audio_file, raw_audio) input_image = self.mel.audio_slice_to_image(slice) input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape( (input_image.height, input_image.width) ) input_image = (input_image / 255) * 2 - 1 input_images = torch.tensor(input_image[np.newaxis, :, :], 
dtype=torch.float).to(self.device) if self.vqvae is not None: input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( generator=generator )[0] input_images = self.vqvae.config.scaling_factor * input_images if start_step > 0: images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) pixels_per_second = ( self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length ) mask_start = int(mask_start_secs * pixels_per_second) mask_end = int(mask_end_secs * pixels_per_second) mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:])) for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])): if isinstance(self.unet, UNet2DConditionModel): model_output = self.unet(images, t, encoding)["sample"] else: model_output = self.unet(images, t)["sample"] if isinstance(self.scheduler, DDIMScheduler): images = self.scheduler.step( model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator, )["prev_sample"] else: images = self.scheduler.step( model_output=model_output, timestep=t, sample=images, generator=step_generator, )["prev_sample"] if mask is not None: if mask_start > 0: images[:, :, :, :mask_start] = mask[:, step, :, :mask_start] if mask_end > 0: images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:] if self.vqvae is not None: # 0.18215 was scaling factor used in training to ensure unit variance images = 1 / self.vqvae.config.scaling_factor * images images = self.vqvae.decode(images)["sample"] images = (images / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).numpy() images = (images * 255).round().astype("uint8") images = list( (Image.fromarray(_[:, :, 0]) for _ in images) if images.shape[3] == 1 else (Image.fromarray(_, mode="RGB").convert("L") for _ in images) ) audios = [self.mel.image_to_audio(_) for _ in images] if not return_dict: return 
@staticmethod
def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
    """Spherical Linear intERPolation

    Args:
        x0 (`torch.Tensor`): first tensor to interpolate between
        x1 (`torch.Tensor`): second tensor to interpolate between
        alpha (`float`): interpolation between 0 and 1

    Returns:
        `torch.Tensor`: interpolated tensor. When the two inputs are (numerically) parallel or
        anti-parallel the spherical formula is undefined, so plain linear interpolation is
        returned instead of NaNs.
    """
    cos_theta = float(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
    # Floating point rounding can leave the cosine marginally outside [-1, 1],
    # which would make `acos` raise a math domain error.
    cos_theta = max(-1.0, min(1.0, cos_theta))
    theta = acos(cos_theta)
    sin_theta = sin(theta)
    if abs(sin_theta) < 1e-6:
        # Dividing by sin(theta) ~ 0 yields NaN/inf; fall back to linear interpolation.
        return (1 - alpha) * x0 + alpha * x1
    return sin((1 - alpha) * theta) * x0 / sin_theta + sin(alpha * theta) * x1 / sin_theta
class AudioLDMPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-audio generation using AudioLDM.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations.
        text_encoder ([`ClapTextModelWithProjection`]):
            Frozen text-encoder. AudioLDM uses the text portion of
            [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection),
            specifically the [RoBERTa HTSAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
        tokenizer ([`PreTrainedTokenizer`]):
            Tokenizer of class
            [RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer).
        unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        vocoder ([`SpeechT5HifiGan`]):
            Vocoder of class
            [SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: ClapTextModelWithProjection,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vocoder: SpeechT5HifiGan,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            vocoder=vocoder,
        )
        # One factor of two per VAE block after the first; used to convert between
        # spectrogram sizes and latent sizes throughout the pipeline.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and vocoder have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index used to build the `cuda:{gpu_id}` execution device handed to accelerate.

        Raises:
            ImportError: if accelerate is not available.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.vocoder]:
            cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when `do_classifier_free_guidance` is `False`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.

        Returns:
            `torch.FloatTensor`: the L2-normalized text embeddings, repeated `num_waveforms_per_prompt` times per
            prompt. With classifier free guidance the negative embeddings are concatenated in front of the prompt
            embeddings along the batch dimension.

        Raises:
            TypeError: if `negative_prompt` and `prompt` are of different types.
            ValueError: if a list `negative_prompt` does not match the batch size of `prompt`.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            # Tokenize again without truncation, only to report what was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLAP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            prompt_embeds = prompt_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            prompt_embeds = F.normalize(prompt_embeds, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        (
            bs_embed,
            seq_len,
        ) = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
        prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # NOTE(review): `prompt_embeds` is 2-D at this point (see the unpacking above), so this is the size
            # of its last dimension rather than a token count - confirm this is the intended `max_length`.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            uncond_input_ids = uncond_input.input_ids.to(device)
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input_ids,
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds.text_embeds
            # additional L_2 normalization over each hidden-state
            negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def decode_latents(self, latents):
        """Undo the VAE scaling factor on `latents` and return the `sample` of the VAE decoder output.

        The caller treats the result as a mel spectrogram and passes it to `mel_spectrogram_to_waveform`.
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        mel_spectrogram = self.vae.decode(latents).sample
        return mel_spectrogram

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        """Run the vocoder on `mel_spectrogram` and return the result as a float32 CPU tensor.

        A 4-dimensional input has its dimension 1 squeezed before it is passed to the vocoder.
        """
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """Validate the arguments of `__call__` and raise `ValueError` on the first violation.

        Checks, in order: the audio length is at least `vocoder_upsample_factor * vae_scale_factor`; the vocoder's
        `model_in_dim` is divisible by the VAE scale factor; `callback_steps` is a positive `int`; exactly one of
        `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`; `negative_prompt` and
        `negative_prompt_embeds` are not both given; and, when both embeddings are given, their shapes match.
        """
        min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
        if audio_length_in_s < min_audio_length_in_s:
            raise ValueError(
                f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
                f"is {audio_length_in_s}."
            )

        if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
            raise ValueError(
                f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
                f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
                f"{self.vae_scale_factor}."
            )

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Builds (or moves to `device`) the initial latents of shape
    # (batch_size, num_channels_latents, height // vae_scale_factor, model_in_dim // vae_scale_factor)
    # and multiplies them by the scheduler's `init_noise_sigma`. Raises `ValueError` when a list of
    # generators does not match `batch_size`.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            self.vocoder.config.model_in_dim // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 2.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`
                instead.
            audio_length_in_s (`float`, *optional*):
                The length of the generated audio sample in seconds. If not given, it is derived from the UNet
                `sample_size`, the VAE scale factor and the vocoder upsample factor.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 2.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate audios that are closely linked to the text `prompt`,
                usually at the expense of lower sound quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is not greater than `1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. Choose between:
                - `"np"`: Return Numpy `np.ndarray` objects.
                - `"pt"`: Return PyTorch `torch.Tensor` objects.

        Examples:

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`:
            [`~pipelines.AudioPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
            its only element is the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        # Remember the requested length so the waveform can be trimmed after `height` is rounded up below.
        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                # (the text embeddings are passed as `class_labels`; no `encoder_hidden_states` are used)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=None,
                    class_labels=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        mel_spectrogram = self.decode_latents(latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        # Trim the samples that only exist because `height` was rounded up in step 0.
        audio = audio[:, :original_waveform_length]

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
>>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2" >>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) >>> pipe.to(device) >>> # Onestep Sampling >>> image = pipe(num_inference_steps=1).images[0] >>> image.save("cd_imagenet64_l2_onestep_sample.png") >>> # Onestep sampling, class-conditional image generation >>> # ImageNet-64 class label 145 corresponds to king penguins >>> image = pipe(num_inference_steps=1, class_labels=145).images[0] >>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png") >>> # Multistep sampling, class-conditional image generation >>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo: >>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77 >>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0] >>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png") ``` """ class ConsistencyModelPipeline(DiffusionPipeline): r""" Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1]. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models" https://arxiv.org/pdf/2303.01469 Args: unet ([`UNet2DModel`]): Unconditional or class-conditional U-Net architecture to denoise image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible with [`CMStochasticIterativeScheduler`]. 
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Register the pipeline's UNet (and the safety checker, when one is set) with accelerate's `cpu_offload`,
    using `cuda:{gpu_id}` as the execution device.

    If the pipeline is not already on the CPU it is first moved there and the CUDA cache is emptied. Offloading
    happens on a submodule basis; memory savings are higher than with `enable_model_cpu_offload`, but performance
    is lower.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): index of the CUDA device used for execution.

    Raises:
        ImportError: if accelerate is missing or older than v0.14.0.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.14.0")
    if not accelerate_ok:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    target = torch.device(f"cuda:{gpu_id}")

    already_on_cpu = self.device.type == "cpu"
    if not already_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        # otherwise we don't see the memory savings (but they probably exist)
        torch.cuda.empty_cache()

    # The UNet is the only registered model that needs offloading in this pipeline.
    cpu_offload(self.unet, target)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=target, offload_buffers=True)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.unet]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels, height, width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
# Follows diffusers.VaeImageProcessor.postprocess
def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"):
    """Map a model output from [-1, 1] to [0, 1] and convert it to the requested format.

    Args:
        sample (`torch.FloatTensor`): tensor to convert; it is permuted with `(0, 2, 3, 1)` for the
            `"np"` and `"pil"` formats.
        output_type (`str`): `"pt"` returns the clamped tensor, `"np"` the permuted NumPy array and
            `"pil"` the result of `self.numpy_to_pil` on that array.

    Raises:
        ValueError: if `output_type` is not one of `"pt"`, `"np"` or `"pil"`.
    """
    supported = ("pt", "np", "pil")
    if output_type not in supported:
        raise ValueError(
            f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
        )

    # Equivalent to diffusers.VaeImageProcessor.denormalize
    denormalized = (sample / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return denormalized

    # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
    as_numpy = denormalized.cpu().permute(0, 2, 3, 1).numpy()

    # Anything that is not 'np' at this point must be 'pil'
    return as_numpy if output_type == "np" else self.numpy_to_pil(as_numpy)
int analogue of randn_tensor is not exposed in ...utils class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,)) class_labels = class_labels.to(device) else: class_labels = None return class_labels def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps): if num_inference_steps is None and timesteps is None: raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.") if num_inference_steps is not None and timesteps is not None: logger.warning( f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;" " `timesteps` will be used over `num_inference_steps`." ) if latents is not None: expected_shape = (batch_size, 3, img_size, img_size) if latents.shape != expected_shape: raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, batch_size: int = 1, class_labels: Optional[Union[torch.Tensor, List[int], int]] = None, num_inference_steps: int = 1, timesteps: List[int] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ): r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*): Optional class labels for conditioning class-conditional consistency models. Will not be used if the model is not class-conditional. 
num_inference_steps (`int`, *optional*, defaults to 1): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. Prepare call parameters img_size = self.unet.config.sample_size device = self._execution_device # 1. 
Check inputs self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps) # 2. Prepare image latents # Sample image latents x_0 ~ N(0, sigma_0^2 * I) sample = self.prepare_latents( batch_size=batch_size, num_channels=self.unet.config.in_channels, height=img_size, width=img_size, dtype=self.unet.dtype, device=device, generator=generator, latents=latents, ) # 3. Handle class_labels for class-conditional models class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels) # 4. Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps # 5. Denoising loop # Multistep sampling: implements Algorithm 1 in the paper with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): scaled_sample = self.scheduler.scale_model_input(sample, t) model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0] sample = self.scheduler.step(model_output, t, sample, generator=generator)[0] # call the callback, if provided progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, sample) # 6. 
class MultiControlNetModel(ModelMixin):
    r"""
    Multiple `ControlNetModel` wrapper class for Multi-ControlNet

    This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
    compatible with `ControlNetModel`.

    Args:
        controlnets (`List[ControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `ControlNetModel` as a list.
    """

    def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
        super().__init__()
        # Registered as a ModuleList so the wrapped nets are submodules of this model.
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        Run every wrapped ControlNet on the same inputs and sum their residuals.

        `controlnet_cond`, `conditioning_scale` and `self.nets` are iterated in lockstep (`zip`), so the i-th
        conditioning image and scale are fed to the i-th ControlNet; iteration stops at the shortest of the
        three. All remaining arguments are forwarded positionally and unchanged to each ControlNet.

        Returns:
            A tuple `(down_block_res_samples, mid_block_res_sample)` holding the element-wise sums of the
            residuals produced by the individual ControlNets.
        """
        # NOTE(review): each net's result is tuple-unpacked below, which presumably only works when the nets
        # are called with `return_dict=False` - confirm against `ControlNetModel.forward`.
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states,
                image,
                scale,
                class_labels,
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
                guess_mode,
                return_dict,
            )

            # merge samples
            if i == 0:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample += mid_sample

        return down_block_res_samples, mid_block_res_sample

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        safe_serialization: bool = False,
        variant: Optional[str] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.

        The first ControlNet is written to `save_directory` itself and the k-th one (k >= 1) to
        `<save_directory>_<k>`, which is exactly the layout `from_pretrained` looks for.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            variant (`str`, *optional*):
                If specified, weights are saved in the format pytorch_model.<variant>.bin.
        """
        # Normalise once so that `os.PathLike` inputs can be combined with the string suffix.
        base_path = os.fspath(save_directory)

        for idx, controlnet in enumerate(self.nets):
            # Always derive the target from the base path: appending to the previous target would
            # compound the suffixes (`dir_1_2`, ...) and `from_pretrained` would not find those folders.
            model_path_to_save = base_path if idx == 0 else f"{base_path}_{idx}"
            controlnet.save_pretrained(
                model_path_to_save,
                is_main_process=is_main_process,
                save_function=save_function,
                safe_serialization=safe_serialization,
                variant=variant,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you should first set it back in training mode with `model.train()`.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_path (`os.PathLike`):
                A path to a *directory* containing model weights saved using
                [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
                `./my_model_directory/controlnet`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                GPU and the available CPU RAM if unset.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
                ignored when using `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.

        Raises:
            ValueError: if `pretrained_model_path` is not an existing directory, i.e. no ControlNet was found.
        """
        # Normalise once so that `os.PathLike` inputs can be combined with the string suffix.
        base_path = os.fspath(pretrained_model_path)
        controlnets = []

        # load controlnet and append to list until no controlnet directory exists anymore
        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with
        # `DiffusionPipeline.from_pretrained`
        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`,
        # `./mydirectory/controlnet_2`, ...
        idx = 0
        model_path_to_load = base_path
        while os.path.isdir(model_path_to_load):
            controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
            controlnets.append(controlnet)

            idx += 1
            model_path_to_load = f"{base_path}_{idx}"

        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")

        if not controlnets:
            # The first ControlNet lives in the un-suffixed directory, so that is the path to report.
            raise ValueError(
                f"No ControlNets found under {os.path.dirname(base_path)}. Expected at least {base_path}."
            )

        return cls(controlnets)
import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... ) >>> image = np.array(image) >>> # get canny image >>> image = cv2.Canny(image, 100, 200) >>> image = image[:, :, None] >>> image = np.concatenate([image, image, image], axis=2) >>> canny_image = Image.fromarray(image) >>> # load control net and stable diffusion v1-5 >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... 
) >>> # speed up diffusion process with faster scheduler and memory optimization >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> # remove following line if xformers is not installed >>> pipe.enable_xformers_memory_efficient_attention() >>> pipe.enable_model_cpu_offload() >>> # generate image >>> generator = torch.manual_seed(0) >>> image = pipe( ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image ... ).images[0] ``` """ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. 
If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) if isinstance(controlnet, (list, tuple)): controlnet = MultiControlNetModel(controlnet) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. 
This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        The embeddings are repeated `num_images_per_prompt` times. With classifier free guidance enabled, the
        (equally repeated) negative embeddings are concatenated in front of the prompt embeddings, so the returned
        tensor holds the unconditional batch first and the conditional batch second.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            `torch.FloatTensor`: the prompt embeddings, preceded by the negative prompt embeddings when
            `do_classifier_free_guidance` is `True`.

        Raises:
            TypeError: if `prompt` and `negative_prompt` are both given but have different types.
            ValueError: if `negative_prompt` is a list whose length differs from the prompt batch size.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        # The batch size comes from `prompt` when given, otherwise from the pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize a second time without truncation, only to detect (and report) text that was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # An attention mask is only passed when the text encoder config asks for one.
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            # Keep only the first element of the text encoder output.
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                # No negative prompt: use one empty string per batch entry.
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad/truncate the negative prompt to the sequence length of the prompt embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """Validate the arguments of a pipeline call before any model is run.

    Args:
        prompt: ``str``, ``list`` or ``None``; mutually exclusive with
            ``prompt_embeds``.
        image: Conditioning image(s); a list with one entry per net when
            several ControlNets are used.
        callback_steps: Positive integer callback frequency.
        negative_prompt: Optional negative prompt; mutually exclusive with
            ``negative_prompt_embeds``.
        prompt_embeds: Optional pre-computed prompt embeddings.
        negative_prompt_embeds: Optional pre-computed negative embeddings; must
            match the shape of ``prompt_embeds`` when both are given.
        controlnet_conditioning_scale: ``float`` for a single ControlNet; for
            multiple ControlNets a list must have one entry per net.
        control_guidance_start: Scalar or list of start fractions in [0, 1].
        control_guidance_end: Scalar or list of end fractions in [0, 1].

    Raises:
        ValueError: On inconsistent or out-of-range arguments.
        TypeError: On arguments of the wrong container/scalar type.
        AssertionError: If ``self.controlnet`` is of an unsupported type.
    """
    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A torch.compile'd controlnet is wrapped in an OptimizedModule; look at
    # the wrapped module to decide whether we have one net or several.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )
    if not is_single and not is_multi:
        # Explicit raise instead of `assert False` so the check survives `python -O`.
        raise AssertionError(f"Unsupported controlnet type: {type(self.controlnet)}")

    # Check `image`
    if is_single:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")
        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # This length check used to sit in an `elif` that repeated the
        # `isinstance(..., list)` test above and therefore could never run.
        if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    # The signature defaults are scalars, but the checks below need sequences.
    # Broadcast scalars the same way `__call__` does so that calling this
    # method with its defaults no longer fails on `len(float)`.
    num_nets = len(self.controlnet.nets) if is_multi else 1
    start_is_seq = isinstance(control_guidance_start, (list, tuple))
    end_is_seq = isinstance(control_guidance_end, (list, tuple))
    if not start_is_seq and not end_is_seq:
        control_guidance_start = [control_guidance_start] * num_nets
        control_guidance_end = [control_guidance_end] * num_nets
    elif not start_is_seq:
        control_guidance_start = [control_guidance_start] * len(control_guidance_end)
    elif not end_is_seq:
        control_guidance_end = [control_guidance_end] * len(control_guidance_start)

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if is_multi:
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess a ControlNet conditioning image and expand it to the batch.

    The image goes through ``self.control_image_processor.preprocess`` and is
    cast to float32. A single image is repeated ``batch_size`` times, while an
    already-batched input is repeated ``num_images_per_prompt`` times per
    entry. The result is moved to ``device``/``dtype`` and, when
    classifier-free guidance is active outside guess mode, stacked twice along
    the batch dimension.
    """
    cond = self.control_image_processor.preprocess(image, height=height, width=width)
    cond = cond.to(dtype=torch.float32)

    # One image is shared by the whole batch; otherwise the image batch
    # already lines up with the prompt batch.
    copies = batch_size if cond.shape[0] == 1 else num_images_per_prompt
    cond = cond.repeat_interleave(copies, dim=0).to(device=device, dtype=dtype)

    if guess_mode or not do_classifier_free_guidance:
        return cond
    return torch.cat([cond, cond])
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        torch.FloatTensor,
        PIL.Image.Image,
        np.ndarray,
        List[torch.FloatTensor],
        List[PIL.Image.Image],
        List[np.ndarray],
    ] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    guess_mode: bool = False,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
                `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
            The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
            the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
            also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
            height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
            specified in init, images must be passed as a list such that each element of the list can be correctly
            batched for input to a single controlnet.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
            corresponding scale as a list.
        guess_mode (`bool`, *optional*, defaults to `False`):
            In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
            you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the controlnet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the controlnet stops applying.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """
    # Unwrap a compiled controlnet so the isinstance checks below see the real model class.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

    # align format for control guidance: after this block both values are lists of equal length
    # (a scalar is repeated to match the other list, or once per controlnet when both are scalars)
    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
            control_guidance_end
        ]

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        image,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        controlnet_conditioning_scale,
        control_guidance_start,
        control_guidance_end,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # A single float scale is applied to every net of a multi-controlnet.
    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

    # For a multi-controlnet only the first net's config is consulted here.
    global_pool_conditions = (
        controlnet.config.global_pool_conditions
        if isinstance(controlnet, ControlNetModel)
        else controlnet.nets[0].config.global_pool_conditions
    )
    guess_mode = guess_mode or global_pool_conditions

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare image
    if isinstance(controlnet, ControlNetModel):
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
            guess_mode=guess_mode,
        )
        # The output size is taken from the preprocessed conditioning image.
        height, width = image.shape[-2:]
    elif isinstance(controlnet, MultiControlNetModel):
        images = []

        for image_ in image:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

            images.append(image_)

        image = images
        # The output size is taken from the first preprocessed conditioning image.
        height, width = image[0].shape[-2:]
    else:
        assert False

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7.1 Create tensor stating which controlnets to keep
    # Each entry is 1.0 when step i lies fully inside [start, end] of the guidance window and 0.0 otherwise;
    # a single controlnet stores a bare float, several controlnets store a list of floats.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # controlnet(s) inference
            if guess_mode and do_classifier_free_guidance:
                # Infer ControlNet only for the conditional batch.
                control_model_input = latents
                control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                # prompt_embeds is [unconditional, conditional]; keep the conditional half.
                controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
            else:
                control_model_input = latent_model_input
                controlnet_prompt_embeds = prompt_embeds

            # Zero the conditioning scale for steps outside the guidance window.
            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                cond_scale = controlnet_conditioning_scale * controlnet_keep[i]

            down_block_res_samples, mid_block_res_sample = self.controlnet(
                control_model_input,
                t,
                encoder_hidden_states=controlnet_prompt_embeds,
                controlnet_cond=image,
                conditioning_scale=cond_scale,
                guess_mode=guess_mode,
                return_dict=False,
            )

            if guess_mode and do_classifier_free_guidance:
                # Inferred ControlNet only for the conditional batch.
                # To apply the output of ControlNet to both the unconditional and conditional batches,
                # add 0 to the unconditional batch to keep it unchanged.
                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # If we do sequential model offloading, let's offload unet and controlnet
    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
        self.controlnet.to("cpu")
        torch.cuda.empty_cache()

    # Latent output skips VAE decoding and the safety checker entirely.
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized during postprocessing.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... ) >>> np_image = np.array(image) >>> # get canny image >>> np_image = cv2.Canny(np_image, 100, 200) >>> np_image = np_image[:, :, None] >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) >>> canny_image = Image.fromarray(np_image) >>> # load control net and stable diffusion v1-5 >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> # speed up diffusion process with faster scheduler and memory optimization >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image >>> generator = torch.manual_seed(0) >>> image = pipe( ... "futuristic-looking woman", ... num_inference_steps=20, ... generator=generator, ... image=image, ... control_image=canny_image, ... 
def prepare_image(image):
    """Convert an input image (or batch) to a float32 tensor.

    Tensor inputs only get a leading batch dimension added when they are 3-D
    and are cast to float32; their values are left unchanged. A PIL image, a
    numpy array, or a list of either is stacked along a new leading axis,
    transposed with ``(0, 3, 1, 2)`` and mapped to ``value / 127.5 - 1.0``.
    """
    if isinstance(image, torch.Tensor):
        # Batch single image
        batched = image.unsqueeze(0) if image.ndim == 3 else image
        return batched.to(dtype=torch.float32)

    # Wrap a lone image so the list handling below covers every case.
    if isinstance(image, (PIL.Image.Image, np.ndarray)):
        image = [image]

    if isinstance(image, list):
        first = image[0]
        if isinstance(first, PIL.Image.Image):
            stacked = [np.array(pil.convert("RGB"))[None, :] for pil in image]
            image = np.concatenate(stacked, axis=0)
        elif isinstance(first, np.ndarray):
            image = np.concatenate([arr[None, :] for arr in image], axis=0)

    channels_first = image.transpose(0, 3, 1, 2)
    return torch.from_numpy(channels_first).to(dtype=torch.float32) / 127.5 - 1.0
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and build the image processors.

    A list or tuple of ControlNets is wrapped in a `MultiControlNetModel`.
    Passing `safety_checker=None` while `requires_safety_checker` is true only
    logs a warning.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix was missing, so `{self.__class__}` was emitted literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # The control image processor is created with normalization disabled.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole models to CPU through accelerate hooks.

    The text encoder, unet and vae (plus the safety checker, when present) are chained with
    `cpu_offload_with_hook`, each hook receiving the previous one. The controlnet gets an independent
    hook. The last hook of the chain is stored on `self.final_offload_hook` so it can be offloaded
    manually later.

    Args:
        gpu_id: Index of the CUDA device used as execution device.

    Raises:
        ImportError: If accelerate is missing or older than `0.17.0.dev0`.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")):
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target_device = torch.device(f"cuda:{gpu_id}")

    # Models that run one after another; the safety checker, if any, comes last
    # so that it can offload the vae again.
    chained_models = [self.text_encoder, self.unet, self.vae]
    if self.safety_checker is not None:
        chained_models.append(self.safety_checker)

    last_hook = None
    for model in chained_models:
        _, last_hook = cpu_offload_with_hook(model, target_device, prev_module_hook=last_hook)

    # The controlnet alternates with the unet, so it is hooked outside the chain
    # and has to be offloaded manually.
    cpu_offload_with_hook(self.controlnet, target_device)

    # The final model of the chain is offloaded manually by the caller.
    self.final_offload_hook = last_hook
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    r"""
    Validate the arguments of a pipeline call and raise on the first problem found.

    Args:
        prompt (`str` or `List[str]`, *optional*): Text prompt(s); exclusive with `prompt_embeds`.
        image: Control image for a single ControlNet, or a flat list with one entry per ControlNet.
        callback_steps (`int`): Must be a positive integer.
        negative_prompt (*optional*): Exclusive with `negative_prompt_embeds`.
        prompt_embeds (*optional*): Pre-computed prompt embeddings.
        negative_prompt_embeds (*optional*): Must have the same shape as `prompt_embeds` when both are given.
        controlnet_conditioning_scale (`float` or `List[float]`): A `float` for a single ControlNet; for
            multiple ControlNets a list must be flat and have one entry per ControlNet.
        control_guidance_start (`float` or `List[float]`): Scalars are expanded to a list, mirroring what
            `__call__` does before invoking this method.
        control_guidance_end (`float` or `List[float]`): Same handling as `control_guidance_start`.

    Raises:
        ValueError: For inconsistent prompts/embeddings, a bad `callback_steps`, nested or wrongly sized
            lists, or guidance ranges outside `0 <= start < end <= 1`.
        TypeError: For an unsupported `controlnet` type, a non-list `image` with multiple ControlNets, or
            a non-float `controlnet_conditioning_scale` with a single ControlNet.
    """
    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A compiled controlnet wraps the real model in `_orig_mod`; classify through the wrapper.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi = isinstance(self.controlnet, MultiControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
    )
    if not (is_single or is_multi):
        # Explicit raise instead of `assert False`, which disappears under `python -O`.
        raise TypeError(f"Unsupported `controlnet` type: {type(self.controlnet)}")

    # Check `image`
    if is_single:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")
        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        if any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        if len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )
        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # This length check used to live in an `elif` that could never run for a list.
        if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    # Align the guidance bounds the same way `__call__` does, so scalar values
    # (including this method's own defaults) do not fail on `len()`.
    start_is_seq = isinstance(control_guidance_start, (list, tuple))
    end_is_seq = isinstance(control_guidance_end, (list, tuple))
    if not start_is_seq and end_is_seq:
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif start_is_seq and not end_is_seq:
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not start_is_seq and not end_is_seq:
        mult = len(self.controlnet.nets) if is_multi else 1
        control_guidance_start = mult * [control_guidance_start]
        control_guidance_end = mult * [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    r"""
    Build the noisy start latents for the denoising loop from the input image.

    An input whose second dimension is 4 is used directly as latents. Anything else is encoded with the
    VAE and multiplied by the VAE scaling factor. The latents are duplicated when the effective batch size
    is a multiple of the image batch size, and finally noised with the scheduler for `timestep`.

    Args:
        image: Input image tensor (or latents); moved to `device`/`dtype` first.
        timestep: Timestep(s) passed to `scheduler.add_noise`.
        batch_size: Prompt batch size; multiplied by `num_images_per_prompt` internally.
        num_images_per_prompt: Number of images generated per prompt.
        dtype: Target dtype of the image and of the sampled noise.
        device: Target device of the image and of the sampled noise.
        generator: Optional generator, or a list with one generator per effective batch item.

    Returns:
        The noised latents.

    Raises:
        ValueError: On an unsupported `image` type, a generator list of the wrong length, or an image
            batch that cannot be evenly duplicated to the effective batch size.
    """
    accepted_types = (torch.Tensor, PIL.Image.Image, list)
    if not isinstance(image, accepted_types):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    # Effective batch size: every prompt yields `num_images_per_prompt` images.
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # Already latents: do not encode again.
        init_latents = image
    else:
        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            # One generator per sample, so each image is encoded and sampled on its own.
            per_sample = []
            for idx in range(batch_size):
                posterior = self.vae.encode(image[idx : idx + 1]).latent_dist
                per_sample.append(posterior.sample(generator[idx]))
            init_latents = torch.cat(per_sample, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0]:
        if batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * copies, dim=0)
    else:
        init_latents = torch.cat([init_latents], dim=0)

    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

    # get latents
    return self.scheduler.add_noise(init_latents, noise, timestep)
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting than for [`~StableDiffusionControlNetPipeline.__call__`]. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying. 
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare image image = self.image_processor.preprocess(image).to(dtype=torch.float32) # 5. Prepare controlnet_conditioning_image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 5. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # If 
we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install transformers accelerate >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> init_image = load_image( ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" ... ) >>> init_image = init_image.resize((512, 512)) >>> generator = torch.Generator(device="cpu").manual_seed(1) >>> mask_image = load_image( ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" ... ) >>> mask_image = mask_image.resize((512, 512)) >>> def make_inpaint_condition(image, image_mask): ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 ... 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
# (the tensor branch no longer binarizes the caller's mask in place).
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
    """
    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (values below ``0.5`` become ``0``, all others become ``1``). Tensor inputs are never modified in
    place: the binarized mask is a fresh copy.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
        height (int): Target height. Only used to resize ``PIL.Image`` inputs.
        width (int): Target width. Only used to resize ``PIL.Image`` inputs.
        return_image (bool): When ``True`` the preprocessed image is returned as a third element.

    Raises:
        ValueError: ``image`` or ``mask`` is ``None``.
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        AssertionError: tensor ``mask`` and ``image`` disagree on rank, spatial dimensions or batch size.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x
        height x width``; the triple (mask, masked_image, image) when ``return_image`` is set.
    """
    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize a copy: `unsqueeze` returns views, so writing into `mask`
        # directly would overwrite the tensor the caller handed in.
        mask = mask.clone()
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t. passed height and width
            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # channels-last -> channels-first, then map [0, 255] to [-1, 1]
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `np.concatenate` already produced a new array, so this is safe in place.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    # Zero out the pixels that are to be inpainted.
    masked_image = image * (mask < 0.5)

    # n.b. ensure backwards compatibility as old function does not return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) as well as default text-to-image stable diffusion checkpoints, such as [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and build the image processors.

    A list or tuple of ControlNets is wrapped into a single `MultiControlNetModel` before registration.

    Raises:
        ValueError: a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}`
        # is printed literally instead of naming the pipeline class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # Control images are converted to RGB but not normalized.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole models to CPU through accelerate's `cpu_offload_with_hook`.

    The text encoder, unet, vae and (when present) the safety checker are hooked in that order, each hook being
    passed the previous one as `prev_module_hook`. The controlnet receives its own hook that is not part of that
    chain, because it alternates with the unet and has to be offloaded manually. The last hook of the chain is kept
    in `self.final_offload_hook` so the final model can be offloaded manually.

    Args:
        gpu_id (`int`, defaults to 0): index of the CUDA device the hooks execute on.

    Raises:
        ImportError: accelerate is missing or older than `0.17.0.dev0`.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_ok:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    device = torch.device(f"cuda:{gpu_id}")

    chained_models = [self.text_encoder, self.unet, self.vae]
    if self.safety_checker is not None:
        # the safety checker can offload the vae again
        chained_models.append(self.safety_checker)

    last_hook = None
    for model in chained_models:
        last_hook = cpu_offload_with_hook(model, device, prev_module_hook=last_hook)[1]

    # control net hook has to be manually offloaded as it alternates with unet
    cpu_offload_with_hook(self.controlnet, device)

    # We'll offload the last model manually.
    self.final_offload_hook = last_hook
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Deprecated: decode `latents` with the VAE and return the result as a numpy array.

    The latents are divided by the VAE scaling factor, decoded, mapped through `x / 2 + 0.5`, clamped to `[0, 1]`,
    moved to CPU, permuted so that axis 1 becomes the last axis, and cast to float32. A `FutureWarning` is emitted
    on every call.
    """
    deprecation_message = (
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead"
    )
    warnings.warn(deprecation_message, FutureWarning)

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    decoded = torch.clamp(decoded / 2 + 0.5, 0, 1)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(
    self,
    prompt,
    image,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """
    Validate the call arguments of the pipeline and raise on the first inconsistency found.

    `control_guidance_start` / `control_guidance_end` may be given as scalars or as lists/tuples. A scalar is
    broadcast to the length of the other argument, or to the number of ControlNets when both are scalars, before
    the per-ControlNet checks run.

    Raises:
        ValueError: bad `height`/`width`, `callback_steps`, prompt / embedding combination, nested or wrongly
            sized conditioning lists, or an invalid control guidance window.
        TypeError: `image` is not a list for multiple ControlNets, or `controlnet_conditioning_scale` is not a
            `float` for a single ControlNet.
        AssertionError: `self.controlnet` is neither a `ControlNetModel` nor a `MultiControlNetModel`.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A `torch.compile`d controlnet is a wrapper; classify the wrapped module instead.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    controlnet = self.controlnet._orig_mod if is_compiled else self.controlnet
    is_single = isinstance(controlnet, ControlNetModel)
    is_multi = not is_single and isinstance(controlnet, MultiControlNetModel)

    # Check `image`
    if is_single:
        self.check_image(image, prompt, prompt_embeds)
    elif is_multi:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        assert False

    # Check `controlnet_conditioning_scale`
    if is_single:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif is_multi:
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            # The length check must run for every flat list; hanging it on an
            # `elif` that repeats the `isinstance(..., list)` test would make
            # it unreachable.
            if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        assert False

    # Align the format of the control guidance window. The declared defaults
    # are scalars, which have no `len()`, so broadcast them to lists first.
    start_is_seq = isinstance(control_guidance_start, (list, tuple))
    end_is_seq = isinstance(control_guidance_end, (list, tuple))
    if not start_is_seq and end_is_seq:
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif start_is_seq and not end_is_seq:
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not start_is_seq and not end_is_seq:
        mult = len(self.controlnet.nets) if is_multi else 1
        control_guidance_start = mult * [control_guidance_start]
        control_guidance_end = mult * [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_control_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, image=None, timestep=None, is_strength_max=True, return_noise=False, return_image_latents=False, ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if (image is None or timestep is None) and not is_strength_max: raise ValueError( "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." "However, either the image or the noise timestep has not been provided." 
) if return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents else: noise = latents.to(device) latents = noise * self.scheduler.init_noise_sigma outputs = (latents,) if return_noise: outputs += (noise,) if return_image_latents: outputs += (image_latents,) return outputs def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different # dimensions for each image, so just checking the first image # is not _exactly_ correct, but it is simple. 
while isinstance(image, list): image = image[0] if height is None: if isinstance(image, PIL.Image.Image): height = image.height elif isinstance(image, torch.Tensor): height = image.shape[2] height = (height // 8) * 8 # round down to nearest multiple of 8 if width is None: if isinstance(image, PIL.Image.Image): width = image.width elif isinstance(image, torch.Tensor): width = image.shape[3] width = (width // 8) * 8 # round down to nearest multiple of 8 return height, width # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents def prepare_mask_latents( self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) ) mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: if not batch_size % mask.shape[0] == 0: raise ValueError( "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" " of masks that you pass is divisible by the total requested batch size." ) mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: raise ValueError( "The passed images and the required batch size don't match. 
Images are supposed to be duplicated" f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." " Make sure the number of images that you pass is divisible by the total requested batch size." ) masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents ) # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) image_latents = self.vae.config.scaling_factor * image_latents return image_latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: Union[torch.Tensor, PIL.Image.Image] = None, mask_image: Union[torch.Tensor, PIL.Image.Image] = None, control_image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, 
List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.5, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. strength (`float`, *optional*, defaults to 1.): Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be between 0 and 1. 
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation.
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet.
If multiple ControlNets are specified in init, you can set the corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting than for [`~StableDiffusionControlNetPipeline.__call__`]. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # 0. 
Default height and width to unet height, width = self._default_height_width(height, width, image) # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 4. Preprocess mask and image - resizes image and mask w.r.t height and width mask, masked_image, init_image = prepare_mask_and_masked_image( image, mask_image, height, width, return_image=True ) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_unet = self.unet.config.in_channels return_image_latents = num_channels_unet == 4 latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, batch_size * num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, do_classifier_free_guidance, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual if num_channels_unet == 9: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents[:1] init_mask = mask[:1] if i < len(timesteps) - 1: 
noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from ..stable_diffusion import FlaxStableDiffusionPipelineOutput from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> import jax.numpy as jnp >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> from diffusers.utils import load_image >>> from PIL import Image >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel >>> def image_grid(imgs, rows, cols): ... w, h = imgs[0].size ... grid = Image.new("RGB", size=(cols * w, rows * h)) ... for i, img in enumerate(imgs): ... 
grid.paste(img, box=(i % cols * w, i // cols * h)) ... return grid >>> def create_key(seed=0): ... return jax.random.PRNGKey(seed) >>> rng = create_key(0) >>> # get canny image >>> canny_image = load_image( ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" ... ) >>> prompts = "best quality, extremely detailed" >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" >>> # load control net and stable diffusion v1-5 >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 ... ) >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 ... ) >>> params["controlnet"] = controlnet_params >>> num_samples = jax.device_count() >>> rng = jax.random.split(rng, jax.device_count()) >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) >>> p_params = replicate(params) >>> prompt_ids = shard(prompt_ids) >>> negative_prompt_ids = shard(negative_prompt_ids) >>> processed_image = shard(processed_image) >>> output = pipe( ... prompt_ids=prompt_ids, ... image=processed_image, ... params=p_params, ... prng_seed=rng, ... num_inference_steps=50, ... neg_prompt_ids=negative_prompt_ids, ... jit=True, ... ).images >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) >>> output_images = image_grid(output_images, num_samples // 4, 4) >>> output_images.save("generated_image.png") ``` """ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. This model inherits from [`FlaxDiffusionPipeline`]. 
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`FlaxAutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`FlaxCLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`FlaxControlNetModel`]): Provides additional conditioning to the unet during the denoising process. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" def __init__( self, vae: FlaxAutoencoderKL, text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, controlnet: FlaxControlNetModel, scheduler: Union[ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, dtype: jnp.dtype = jnp.float32, ): super().__init__() self.dtype = dtype if safety_checker is None: logger.warn( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_text_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") text_input = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) return text_input.input_ids def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): if not isinstance(image, (Image.Image, list)): raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") if isinstance(image, Image.Image): image = [image] processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) return processed_images def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) return has_nsfw_concepts def _run_safety_checker(self, images, safety_model_params, jit=False): # safety_model_params should already be replicated when jit is True pil_images = [Image.fromarray(image) for image in images] features = self.feature_extractor(pil_images, return_tensors="np").pixel_values if jit: features = shard(features) has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) has_nsfw_concepts = unshard(has_nsfw_concepts) safety_model_params = unreplicate(safety_model_params) else: has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) images_was_copied = False for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): if has_nsfw_concept: if not images_was_copied: images_was_copied = True images = images.copy() images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image if 
any(has_nsfw_concepts): warnings.warn( "Potential NSFW content was detected in one or more images. A black image will be returned" " instead. Try again with a different prompt and/or seed." ) return images, has_nsfw_concepts def _generate( self, prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, num_inference_steps: int, guidance_scale: float, latents: Optional[jnp.array] = None, neg_prompt_ids: Optional[jnp.array] = None, controlnet_conditioning_scale: float = 1.0, ): height, width = image.shape[-2:] if height % 64 != 0 or width % 64 != 0: raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") # get prompt text embeddings prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] if neg_prompt_ids is None: uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" ).input_ids else: uncond_input = neg_prompt_ids negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) image = jnp.concatenate([image] * 2) latents_shape = ( batch_size, self.unet.config.in_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") def loop_body(step, args): latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) down_block_res_samples, mid_block_res_sample = self.controlnet.apply( {"params": params["controlnet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, controlnet_cond=image, conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) # predict the noise residual noise_pred = self.unet.apply( {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, ).sample # perform guidance noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape ) # scale the initial noise by the standard deviation required by the scheduler latents = latents * params["scheduler"].init_noise_sigma if DEBUG: # run with python for loop for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) else: latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, 
method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, num_inference_steps: int = 50, guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, neg_prompt_ids: jnp.array = None, controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, return_dict: bool = True, jit: bool = False, ): r""" Function invoked when calling the pipeline for generation. Args: prompt_ids (`jnp.array`): The prompt or prompts to guide the image generation. image (`jnp.array`): Array representing the ControlNet input condition. ControlNet use this input condition to generate guidance to Unet. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. latents (`jnp.array`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. 
controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. Examples: Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ height, width = image.shape[-2:] if isinstance(guidance_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) if len(prompt_ids.shape) > 2: # Assume sharded guidance_scale = guidance_scale[:, None] if isinstance(controlnet_conditioning_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. 
controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) if len(prompt_ids.shape) > 2: # Assume sharded controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] if jit: images = _p_generate( self, prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ) else: images = self._generate( prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ) if self.safety_checker is not None: safety_params = params["safety_checker"] images_uint8_casted = (images * 255).round().astype("uint8") num_devices, batch_size = images.shape[:2] images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) images = np.array(images) # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): if is_nsfw: images[i] = np.asarray(images_uint8_casted[i]) images = images.reshape(num_devices, batch_size, height, width, 3) else: images = np.asarray(images) has_nsfw_concept = False if not return_dict: return (images, has_nsfw_concept) return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) # Static argnums are pipe, num_inference_steps. A change would trigger recompilation. # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). 
@partial( jax.pmap, in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), static_broadcasted_argnums=(0, 5), ) def _p_generate( pipe, prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ): return pipe._generate( prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) def _p_get_has_nsfw_concepts(pipe, features, params): return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): # einops.rearrange(x, 'd b ... -> (d b) ...') num_devices, batch_size = x.shape[:2] rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) def preprocess(image, dtype): image = image.convert("RGB") w, h = image.size w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) return image ================================================ FILE: 4DoF/diffusers/pipelines/dance_diffusion/__init__.py ================================================ from .pipeline_dance_diffusion import DanceDiffusionPipeline ================================================ FILE: 4DoF/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch

from ...utils import logging, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class DanceDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional audio generation with a 1D U-Net.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded audio.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio. Can be one of
            [`IPNDMScheduler`].
    """

    def __init__(self, unet, scheduler):
        """Register `unet` and `scheduler` as modules of the pipeline."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 100,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        audio_length_in_s: Optional[float] = None,
        return_dict: bool = True,
    ) -> Union[AudioPipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of audio samples to generate.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at
                the expense of slower inference.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`):
                The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.*
                `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the array of generated audio
            samples.

        Raises:
            ValueError: if the requested length is shorter than three times the U-Net down-scaling factor, or if a
                list of generators is passed whose length differs from `batch_size`.
        """

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate

        sample_size = audio_length_in_s * self.unet.config.sample_rate

        down_scale_factor = 2 ** len(self.unet.up_blocks)
        if sample_size < 3 * down_scale_factor:
            raise ValueError(
                f"{audio_length_in_s} is too small. Make sure it's bigger or equal to"
                f" {3 * down_scale_factor / self.unet.config.sample_rate}."
            )

        # Remember the requested length: the sample is padded up to a multiple of the down-scaling factor for
        # the model and cut back to this length after denoising.
        original_sample_size = int(sample_size)
        if sample_size % down_scale_factor != 0:
            sample_size = (
                (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1
            ) * down_scale_factor
            logger.info(
                f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled"
                f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising"
                " process."
            )
        sample_size = int(sample_size)

        dtype = next(iter(self.unet.parameters())).dtype
        shape = (batch_size, self.unet.config.in_channels, sample_size)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
        # cast the timesteps to the dtype of the U-Net parameters
        self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(audio, t).sample

            # 2. compute previous sample: x_t -> x_t-1
            audio = self.scheduler.step(model_output, t, audio).prev_sample

        audio = audio.clamp(-1, 1).float().cpu().numpy()

        # drop the padding that was added to reach a multiple of the down-scaling factor
        audio = audio[:, :, :original_sample_size]

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)


================================================
FILE: 4DoF/diffusers/pipelines/ddim/__init__.py
================================================
from .pipeline_ddim import DDIMPipeline


================================================
FILE: 4DoF/diffusers/pipelines/ddim/pipeline_ddim.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch

from ...schedulers import DDIMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class DDIMPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation using a DDIM scheduler.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        """Register `unet` and a `DDIMScheduler` rebuilt from the configuration of `scheduler`."""
        super().__init__()

        # make sure scheduler can always be converted to DDIM
        scheduler = DDIMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
                downstream to the scheduler. So use `None` for schedulers which don't support this argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.

        Raises:
            ValueError: if a list of generators is passed whose length differs from `batch_size`.
        """

        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(
                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
            ).prev_sample

        # map from [-1, 1] to [0, 1] and move to channels-last numpy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)


================================================
FILE: 4DoF/diffusers/pipelines/ddpm/__init__.py
================================================
from .pipeline_ddpm import DDPMPipeline


================================================
FILE: 4DoF/diffusers/pipelines/ddpm/pipeline_ddpm.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch

from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class DDPMPipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation with a 2D U-Net and a scheduler.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        """Register `unet` and `scheduler` as modules of the pipeline."""
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # Sample gaussian noise to begin loop
        if isinstance(self.unet.config.sample_size, int):
            image_shape = (
                batch_size,
                self.unet.config.in_channels,
                self.unet.config.sample_size,
                self.unet.config.sample_size,
            )
        else:
            image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)

        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            # so sample without a target device first, then move the result over
            image = randn_tensor(image_shape, generator=generator)
            image = image.to(self.device)
        else:
            image = randn_tensor(image_shape, generator=generator, device=self.device)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. compute previous image: x_t -> x_t-1
            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

        # map from [-1, 1] to [0, 1] and move to channels-last numpy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)


================================================
FILE: 4DoF/diffusers/pipelines/deepfloyd_if/__init__.py
================================================
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL

from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
from .timesteps import (
    fast27_timesteps,
    smart27_timesteps,
    smart50_timesteps,
    smart100_timesteps,
    smart185_timesteps,
    super27_timesteps,
    super40_timesteps,
    super100_timesteps,
)


@dataclass
class IFPipelineOutput(BaseOutput):
    """
    Args:
    Output class for Stable Diffusion pipelines.
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content or a watermark. `None` if safety checking could not be performed.
        watermark_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
            checking could not be performed.
""" images: Union[List[PIL.Image.Image], np.ndarray] nsfw_detected: Optional[List[bool]] watermark_detected: Optional[List[bool]] try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_if import IFPipeline from .pipeline_if_img2img import IFImg2ImgPipeline from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline from .pipeline_if_inpainting import IFInpaintingPipeline from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline from .pipeline_if_superresolution import IFSuperResolutionPipeline from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker ================================================ FILE: 4DoF/diffusers/pipelines/deepfloyd_if/pipeline_if.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, is_accelerate_available, is_accelerate_version, is_bs4_available, is_ftfy_available, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
import IFPipelineOutput from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline >>> from diffusers.utils import pt_to_pil >>> import torch >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) >>> pipe.enable_model_cpu_offload() >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" ... ).images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> safety_modules = { ... "feature_extractor": pipe.feature_extractor, ... "safety_checker": pipe.safety_checker, ... "watermarker": pipe.watermarker, ... } >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 ... ) >>> super_res_2_pipe.enable_model_cpu_offload() >>> image = super_res_2_pipe( ... prompt=prompt, ... image=image, ... 
class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
    """
    Text-conditioned image generation pipeline.

    The pipeline is assembled from a T5 tokenizer/text encoder, a conditional UNet and a DDPM scheduler. A safety
    checker (with its feature extractor) and a watermarker are optional components that are applied to the final
    images when present.
    """

    tokenizer: T5Tokenizer
    text_encoder: T5EncoderModel

    unet: UNet2DConditionModel
    scheduler: DDPMScheduler

    feature_extractor: Optional[CLIPImageProcessor]
    safety_checker: Optional[IFSafetyChecker]

    watermarker: Optional[IFWatermarker]

    # Character class of punctuation that `_clean_caption` replaces with a space.
    # Written as a single raw string so that no invalid escape sequences are needed.
    bad_punct_regex = re.compile(r"[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}")  # noqa

    _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]

    def __init__(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        safety_checker: Optional[IFSafetyChecker],
        feature_extractor: Optional[CLIPImageProcessor],
        watermarker: Optional[IFWatermarker],
        requires_safety_checker: bool = True,
    ):
        """
        Register the pipeline components.

        Logs a warning when the safety checker is disabled while `requires_safety_checker` is `True`.

        Raises:
            ValueError: If a safety checker is given without a feature extractor.
        """
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the IF license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # f-string so that the pipeline class is actually interpolated into the message
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            watermarker=watermarker,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.text_encoder,
            self.unet,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Raises:
            ImportError: If `accelerate` is missing or older than `0.17.0.dev0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None

        if self.text_encoder is not None:
            _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)

            # Accelerate will move the next model to the device _before_ calling the offload hook of the
            # previous model. This will cause both models to be present on the device at the same time.
            # IF uses T5 for its text encoder which is really large. We can manually call the offload
            # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
            # the GPU.
            self.text_encoder_offload_hook = hook

        _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)

        # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
        self.unet_offload_hook = hook

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    def remove_all_hooks(self):
        """
        Remove the accelerate hooks from the text encoder, unet and safety checker and reset the stored offload hooks.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate.hooks import remove_hook_from_module
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        for model in [self.text_encoder, self.unet, self.safety_checker]:
            if model is not None:
                remove_hook_from_module(model, recurse=True)

        self.unet_offload_hook = None
        self.text_encoder_offload_hook = None
        self.final_offload_hook = None

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def encode_prompt(
        self,
        prompt,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1,
        device=None,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        clean_caption: bool = False,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on. Defaults to the pipeline's execution device.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            clean_caption (`bool`, *optional*, defaults to `False`):
                Whether to run the caption cleaning of `_text_preprocessing` before tokenization.

        Returns:
            `tuple`: `(prompt_embeds, negative_prompt_embeds)`; the second element is `None` when classifier free
            guidance is disabled.

        Raises:
            TypeError: If `prompt` and `negative_prompt` are both given but have different types.
            ValueError: If a list `negative_prompt` does not match the batch size of `prompt`.
        """
        if prompt is not None and negative_prompt is not None:
            if type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )

        if device is None:
            device = self._execution_device

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
        max_length = 77

        if prompt_embeds is None:
            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
                # The tokenizer/encoder of this pipeline is T5, so the message must not blame CLIP.
                logger.warning(
                    "The following part of your input was truncated because the text encoder can only handle"
                    f" sequences up to {max_length} tokens: {removed_text}"
                )

            attention_mask = text_inputs.attention_mask.to(device)

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        # Pick the dtype of whichever model is available; `None` leaves the dtype untouched in `.to(...)`.
        if self.text_encoder is not None:
            dtype = self.text_encoder.dtype
        elif self.unet is not None:
            dtype = self.unet.dtype
        else:
            dtype = None

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
            # pad the unconditional input to the same sequence length as the prompt embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            attention_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
        else:
            negative_prompt_embeds = None

        return prompt_embeds, negative_prompt_embeds

    def run_safety_checker(self, image, device, dtype):
        """
        Run the safety checker on `image` if one is registered.

        Without a safety checker both flags are `None` and, if a unet offload hook is stored, the unet is offloaded
        manually because no following model would trigger it.

        Returns:
            `tuple`: `(image, nsfw_detected, watermark_detected)`.
        """
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, nsfw_detected, watermark_detected = self.safety_checker(
                images=image,
                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
            )
        else:
            nsfw_detected = None
            watermark_detected = None

            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()

        return image, nsfw_detected, watermark_detected

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        """Build the keyword arguments (`eta`, `generator`) that the scheduler's `step` method accepts."""
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If `callback_steps` is not a positive integer, if `prompt`/`prompt_embeds` (or their negative
                counterparts) are both given, if neither `prompt` nor `prompt_embeds` is given, if `prompt` has an
                unsupported type, or if the two embedding tensors differ in shape.
        """
        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator):
        """
        Draw the initial noise of shape `(batch_size, num_channels, height, width)`, scaled for the scheduler.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
        return intermediate_images

    def _text_preprocessing(self, text, clean_caption=False):
        """
        Normalize one prompt or a list/tuple of prompts and return a list of strings.

        With `clean_caption=True` (and both `bs4` and `ftfy` available) each prompt goes through `_clean_caption`;
        otherwise it is only lower-cased and stripped. When a dependency is missing a warning is logged and the
        simple normalization is used.
        """
        if clean_caption and not is_bs4_available():
            logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        if clean_caption and not is_ftfy_available():
            logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
            logger.warning("Setting `clean_caption` to False...")
            clean_caption = False

        if not isinstance(text, (tuple, list)):
            text = [text]

        def process(text: str):
            if clean_caption:
                # NOTE(review): the cleaning is deliberately kept as two passes to preserve the existing output;
                # presumably the second pass removes patterns exposed by the first — confirm before simplifying.
                text = self._clean_caption(text)
                text = self._clean_caption(text)
            else:
                text = text.lower().strip()
            return text

        return [process(t) for t in text]

    def _clean_caption(self, caption):
        """
        Clean a raw caption: lower-case it and strip URLs, HTML, handles, CJK characters, ids, file names,
        marketing boilerplate and stray punctuation.

        Requires `bs4` (`BeautifulSoup`) and `ftfy`; callers go through `_text_preprocessing`, which checks for both.
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        # an empty pattern here would insert "person" between every character, so match the literal tag
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text

        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )

        # normalize quotation marks to one standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)

        # &quot; entity (with optional semicolon)
        caption = re.sub(r"&quot;?", "", caption)
        # &amp entity; only the entity is removed so that other entities survive until `html.unescape` below
        caption = re.sub(r"&amp", "", caption)

        # ip adresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)

        # \n
        caption = re.sub(r"\\n", " ", caption)

        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)

        caption = ftfy.fix_text(caption)
        caption = html.unescape(html.unescape(caption))

        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)

        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)

        # `str.strip` returns a new string; keep the result so the anchored patterns below see trimmed text
        caption = caption.strip()

        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)

        return caption.strip()

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 100,
        timesteps: Optional[List[int]] = None,
        guidance_scale: float = 7.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        clean_caption: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` returns `PIL.Image.Image` objects, `"pt"` returns
                the raw tensor without running the safety checker, and any other value returns an `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.deepfloyd_if.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.deepfloyd_if.IFPipelineOutput`] or `tuple`:
            [`~pipelines.deepfloyd_if.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the generated images, the second element is a list of `bool`s
            denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
            content, and the third element is a list of `bool`s denoting whether it likely has a watermark,
            according to the `safety_checker`. Both lists are `None` when no safety check was run.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        height = height or self.unet.config.sample_size
        width = width or self.unet.config.sample_size

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clean_caption=clean_caption,
        )

        if do_classifier_free_guidance:
            # unconditional half first, text half second; the loop below relies on this order when chunking
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

        # 5. Prepare intermediate images
        intermediate_images = self.prepare_intermediate_images(
            batch_size * num_images_per_prompt,
            self.unet.config.in_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # HACK: see comment in `enable_model_cpu_offload`
        if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
            self.text_encoder_offload_hook.offload()

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                model_input = (
                    torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
                )
                model_input = self.scheduler.scale_model_input(model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    # the unet output is split along dim 1 into the noise part and the predicted variance
                    noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                    noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                    noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

                # compute the previous noisy sample x_t -> x_t-1
                intermediate_images = self.scheduler.step(
                    noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, intermediate_images)

        image = intermediate_images

        if output_type == "pil":
            # 8. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 9. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)

            # 11. Apply watermark
            if self.watermarker is not None:
                image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
        elif output_type == "pt":
            # the raw tensor is returned as is; no safety check is run on it
            nsfw_detected = None
            watermark_detected = None

            if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
                self.unet_offload_hook.offload()
        else:
            # 8. Post-processing
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

            # 9. Run safety checker
            image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, nsfw_detected, watermark_detected)

        return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so one side equals `img_size`, keeping the aspect ratio.

    For landscape or square inputs the height becomes `img_size` and the width
    is scaled by the aspect ratio; for portrait inputs the width becomes
    `img_size` and the height is scaled. The scaled side is rounded to a
    multiple of 8.

    Args:
        images: the image to resize; only `.size` and `.resize` are used.
        img_size: target length of the fixed side, in pixels.

    Returns:
        The resized image, produced with bicubic resampling.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # wide (or square) image: pin the height, scale the width
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # tall image: pin the width, scale the height
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and record the safety-checker requirement.

    Args:
        tokenizer: tokenizer registered under the name `tokenizer`.
        text_encoder: text encoder registered under the name `text_encoder`.
        unet: denoising model registered under the name `unet`.
        scheduler: scheduler registered under the name `scheduler`.
        safety_checker: optional safety checker; when given, a
            `feature_extractor` must be given as well.
        feature_extractor: optional feature extractor used to build the safety
            checker input.
        watermarker: optional watermarker applied to PIL outputs.
        requires_safety_checker: stored in the pipeline config; when `True` and
            `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: if `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise `{self.__class__}` is emitted literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
self.final_offload_hook = hook # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks def remove_all_hooks(self): if is_accelerate_available(): from accelerate.hooks import remove_hook_from_module else: raise ImportError("Please install accelerate via `pip install accelerate`") for model in [self.text_encoder, self.unet, self.safety_checker]: if model is not None: remove_hook_from_module(model, recurse=True) self.unet_offload_hook = None self.text_encoder_offload_hook = None self.final_offload_hook = None @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif 
isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and
    `generator` are only forwarded when `step` declares a parameter of that
    name. `eta` corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: value to pass as `generator` when supported.
        eta: value to pass as `eta` when supported.

    Returns:
        A dict containing `"eta"` and/or `"generator"`, in that order, limited
        to the names found in the scheduler's `step` signature.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_kwargs = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_kwargs if name in step_parameters}
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if isinstance(image, list): check_image_type = image[0] else: check_image_type = image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" f" {type(check_image_type)}" ) if isinstance(image, list): image_batch_size = len(image) elif isinstance(image, torch.Tensor): image_batch_size = image.shape[0] elif isinstance(image, PIL.Image.Image): image_batch_size = 1 elif isinstance(image, np.ndarray): image_batch_size = image.shape[0] else: assert False if batch_size != image_batch_size: raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") clean_caption = False if clean_caption and not is_ftfy_available(): logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") clean_caption = False if not isinstance(text, (tuple, list)): text = [text] def process(text: str): if clean_caption: text = self._clean_caption(text) text = self._clean_caption(text) else: text = text.lower().strip() return text return [process(t) for t in text] # Copied from 
def _clean_caption(self, caption):
    """Strip web/markup noise from a caption and return the cleaned, lower-cased text.

    The caption is URL-unquoted and lower-cased, then passed through a fixed
    sequence of substitutions (URLs, HTML, @-handles, CJK ranges, dash and
    quote normalization, ids, filenames, marketing phrases, stray
    punctuation). Uses `BeautifulSoup` and `ftfy`, which must be importable.

    Args:
        caption: the raw caption; converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # replace the literal "<person>" placeholder; an empty pattern here would match
    # between every character and corrupt the caption
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to a single standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot;
    caption = re.sub(r"&quot;?", "", caption)
    # &amp
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # `str.strip` returns a new string; keep the result so the anchored (^...$)
    # substitutions below see the trimmed caption
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.7,
    num_inference_steps: int = 80,
    timesteps: List[int] = None,
    guidance_scale: float = 10.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.FloatTensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        strength (`float`, *optional*, defaults to 0.7):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 80):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 10.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns a list of
            [PIL](https://pillow.readthedocs.io/en/stable/) images (safety-checked and, if a watermarker is set,
            watermarked); `"pt"` returns the raw image tensor without running the safety checker; any other value
            returns a safety-checked `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
        returning a tuple, the first element is a list with the generated images, and the second element is a list
        of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
        or watermarked content, according to the `safety_checker`.
    """
    # 1. Check inputs. Raise error if not correct
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        # Neither a usable `prompt` nor `prompt_embeds`: leave the batch size unset so that
        # `check_inputs` raises its descriptive ValueError instead of an AttributeError here.
        batch_size = None

    self.check_inputs(
        prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    if do_classifier_free_guidance:
        # unconditional embeddings first, so `chunk(2)` below yields (uncond, text)
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. Prepare intermediate images
    image = self.preprocess_image(image)
    image = image.to(device=device, dtype=dtype)

    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_input = (
                torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
            )
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL
        image = self.numpy_to_pil(image)

        # 11. Apply watermark
        if self.watermarker is not None:
            # keep the returned images rather than relying on in-place modification
            image = self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        nsfw_detected = None
        watermark_detected = None

        # the safety checker is not run for "pt" output, so the unet is offloaded manually
        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize a PIL image so its shorter side equals ``img_size``.

    The longer side is scaled by the source aspect ratio and rounded to a
    multiple of 8. Resampling is bicubic.

    Args:
        images: The PIL image to resize (a single image, despite the name).
        img_size: Target length, in pixels, of the shorter side.

    Returns:
        The resized PIL image.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # Landscape or square: height is fixed, width follows the aspect ratio.
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is fixed, height follows the inverse aspect ratio.
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate their combination.

    Args:
        tokenizer: Tokenizer paired with ``text_encoder``.
        text_encoder: T5 encoder used to embed prompts.
        unet: Denoising UNet; a warning is logged unless it has 6 input channels.
        scheduler: Scheduler registered under ``scheduler``.
        image_noising_scheduler: Scheduler registered under ``image_noising_scheduler``.
        safety_checker: Optional safety checker; requires ``feature_extractor``.
        feature_extractor: Image processor that feeds the safety checker.
        watermarker: Optional watermarker.
        requires_safety_checker: When True, a warning is logged if
            ``safety_checker`` is None. Stored in the pipeline config.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The f-prefix was missing, so `{self.__class__}` was printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # The f-prefix was missing here too; `Logger.warn` is a deprecated alias of `warning`.
        logger.warning(
            f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.text_encoder, self.unet, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. 
def _clean_caption(self, caption):
    """Normalize a raw caption before tokenization.

    The caption is URL-unquoted, lower-cased, and stripped of URLs, HTML
    markup, @-handles, CJK ranges, ids, file names, boilerplate phrases and
    stray punctuation. Uses the module-level ``BeautifulSoup`` and ``ftfy``
    names, so both optional dependencies must be installed.

    Args:
        caption: Any object; it is converted with ``str()`` first.

    Returns:
        The cleaned caption string.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # An empty pattern here would insert "person" between every character.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @nickname handles
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to a single standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # leftover HTML quote entity
    caption = re.sub(r"&quot;?", "", caption)
    # leftover HTML ampersand entity
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated quotes and dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # The result of strip() was previously discarded, so the anchored
    # patterns below could miss captions padded by earlier substitutions.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over ``image``.

    Args:
        image: Image batch, passed to ``self.numpy_to_pil`` and to the checker.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the checker's ``clip_input`` is cast to.

    Returns:
        ``(image, nsfw_detected, watermark_detected)``. Both flags are ``None``
        when no safety checker is configured.
    """
    if self.safety_checker is None:
        # With no checker to run, offload the UNet through its hook when one is set.
        unet_hook = getattr(self, "unet_offload_hook", None)
        if unet_hook is not None:
            unet_hook.offload()
        return image, None, None

    pil_images = self.numpy_to_pil(image)
    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_features.pixel_values.to(dtype=dtype)
    image, nsfw_detected, watermark_detected = self.safety_checker(images=image, clip_input=clip_input)
    return image, nsfw_detected, watermark_detected
def check_inputs(
    self,
    prompt,
    image,
    original_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call.

    Args:
        prompt: ``str``, ``list`` or None; mutually exclusive with ``prompt_embeds``.
        image: Tensor, PIL image, ndarray, or a non-empty list of those.
        original_image: Same accepted types as ``image``.
        batch_size: Prompt batch size both image inputs must match.
        callback_steps: Must be a positive ``int``.
        negative_prompt: Mutually exclusive with ``negative_prompt_embeds``.
        prompt_embeds: Pre-computed prompt embeddings, or None.
        negative_prompt_embeds: Pre-computed negative embeddings; when given
            together with ``prompt_embeds`` the shapes must be equal.

    Raises:
        ValueError: If any argument is missing, conflicting, of the wrong
            type, an empty list, or has a mismatching batch size.
    """

    def _validated_batch_size(value, name):
        # Check the type of `value` (or of its first element) and return its batch size.
        if isinstance(value, list):
            if not value:
                # Previously an empty list surfaced as an IndexError from `value[0]`.
                raise ValueError(f"`{name}` must not be an empty list")
            sample = value[0]
        else:
            sample = value

        accepted = isinstance(sample, (torch.Tensor, np.ndarray)) or isinstance(sample, PIL.Image.Image)
        if not accepted:
            raise ValueError(
                f"`{name}` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(sample)}"
            )

        if isinstance(value, list):
            return len(value)
        if isinstance(value, (torch.Tensor, np.ndarray)):
            return value.shape[0]
        # Only a single PIL image can reach this point.
        return 1

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # image
    image_batch_size = _validated_batch_size(image, "image")
    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")

    # original_image
    image_batch_size = _validated_batch_size(original_image, "original_image")
    if batch_size != image_batch_size:
        raise ValueError(
            f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
        )
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of the scheduler timesteps that ``strength`` asks for.

    Args:
        num_inference_steps: Number of timesteps configured on the scheduler.
        strength: Fraction of the schedule to run; the resulting step count
            is capped at ``num_inference_steps``.

    Returns:
        A ``(timesteps, steps)`` tuple: the slice of
        ``self.scheduler.timesteps`` to iterate over, and the number of
        steps that remain after skipping the start of the schedule.
    """
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    # Index of the first timestep to keep; never negative.
    first_index = num_inference_steps - steps_to_run
    if first_index < 0:
        first_index = 0

    return self.scheduler.timesteps[first_index:], num_inference_steps - first_index
) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) image = image.repeat_interleave(num_images_per_prompt, dim=0) image = self.scheduler.add_noise(image, noise, timestep) return image @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor], original_image: Union[ PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray] ] = None, strength: float = 0.8, prompt: Union[str, List[str]] = None, num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 4.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, noise_level: int = 250, clean_caption: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. original_image (`torch.FloatTensor` or `PIL.Image.Image`): The original image that `image` was varied from. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to 250): The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)` clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. 
Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] self.check_inputs( prompt, image, original_image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. Define call parameters # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 device = self._execution_device # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) dtype = prompt_embeds.dtype # 4. Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) # 5. 
prepare original image original_image = self.preprocess_original_image(original_image) original_image = original_image.to(device=device, dtype=dtype) # 6. Prepare intermediate images noise_timestep = timesteps[0:1] noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) intermediate_images = self.prepare_intermediate_images( original_image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator, ) # 7. Prepare upscaled image and noise level _, _, height, width = original_image.shape image = self.preprocess_image(image, num_images_per_prompt, device) upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) if do_classifier_free_guidance: noise_level = torch.cat([noise_level] * 2) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = torch.cat([intermediate_images, upscaled], dim=1) model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = intermediate_images if output_type == "pil": # 10. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 11. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 12. Convert to PIL image = self.numpy_to_pil(image) # 13. 
Apply watermark if self.watermarker is not None: self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() else: # 10. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 11. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, nsfw_detected, watermark_detected) return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) ================================================ FILE: 4DoF/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, is_bs4_available, is_ftfy_available, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so that its shorter side equals ``img_size``.

    The longer side is scaled by the original width/height ratio and rounded
    to a multiple of 8. Resampling uses the filter stored under the
    ``"bicubic"`` key of ``PIL_INTERPOLATION``.

    Args:
        images: Image to resize; only its ``size`` and ``resize`` are used.
        img_size: Target length, in pixels, of the shorter side.

    Returns:
        The resized image.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # Landscape or square: height is pinned, width follows the ratio.
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is pinned, height follows the ratio.
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and the safety-checker requirement.

    Args:
        tokenizer: Tokenizer registered as the ``tokenizer`` module.
        text_encoder: Text encoder registered as the ``text_encoder`` module.
        unet: UNet registered as the ``unet`` module.
        scheduler: Scheduler registered as the ``scheduler`` module.
        safety_checker: Optional safety checker. Passing ``None`` while
            ``requires_safety_checker`` is true only logs a warning.
        feature_extractor: Optional feature extractor; must be provided
            whenever ``safety_checker`` is provided.
        watermarker: Optional watermarker registered as ``watermarker``.
        requires_safety_checker: Stored in the pipeline config.

    Raises:
        ValueError: If ``safety_checker`` is given without a
            ``feature_extractor``.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message must be an f-string, otherwise the `{self.__class__}`
        # placeholder is emitted verbatim instead of the class name.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
def remove_all_hooks(self):
    """Strip accelerate hooks from the pipeline models and forget the offload hooks.

    Hooks are removed recursively from the text encoder, the unet and the
    safety checker (whichever of them are not ``None``); afterwards the three
    stored offload-hook attributes are reset to ``None``.

    Raises:
        ImportError: If accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate.hooks import remove_hook_from_module

    hooked_modules = (self.text_encoder, self.unet, self.safety_checker)
    for module in hooked_modules:
        if module is None:
            continue
        remove_hook_from_module(module, recurse=True)

    self.unet_offload_hook = self.text_encoder_offload_hook = self.final_offload_hook = None
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif 
isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker on ``image``.

    Args:
        image: Images handed to ``numpy_to_pil`` and to the safety checker.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted pixel values are cast to.

    Returns:
        A tuple ``(image, nsfw_detected, watermark_detected)``. Without a
        safety checker the image is returned untouched and both flags are
        ``None``.
    """
    checker = self.safety_checker

    if checker is None:
        # No checker will run, so a pending unet offload hook (if any)
        # is triggered here instead.
        pending_hook = getattr(self, "unet_offload_hook", None)
        if pending_hook is not None:
            pending_hook.offload()
        return image, None, None

    extractor = self.feature_extractor
    pil_images = self.numpy_to_pil(image)
    features = extractor(pil_images, return_tensors="pt").to(device)
    clip_input = features.pixel_values.to(dtype=dtype)

    image, nsfw_detected, watermark_detected = checker(images=image, clip_input=clip_input)
    return image, nsfw_detected, watermark_detected
def check_inputs(
    self,
    prompt,
    image,
    mask_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call.

    Checks, in order: ``callback_steps`` is a positive int; exactly one of
    ``prompt`` / ``prompt_embeds`` is given and ``prompt`` is a ``str`` or
    ``list``; ``negative_prompt`` and ``negative_prompt_embeds`` are not both
    given; directly passed embeddings share a shape; ``image`` and
    ``mask_image`` have a supported type; the image batch size equals
    ``batch_size``; the mask batch size is ``1`` or equals ``batch_size``.

    Raises:
        ValueError: On the first check that fails.
    """
    if not (isinstance(callback_steps, int) and callback_steps > 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None

    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if (
        has_prompt_embeds
        and negative_prompt_embeds is not None
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # --- image: the type is judged from the first element when a list is passed ---
    image_sample = image[0] if isinstance(image, list) else image

    if not (
        isinstance(image_sample, torch.Tensor)
        or isinstance(image_sample, PIL.Image.Image)
        or isinstance(image_sample, np.ndarray)
    ):
        raise ValueError(
            "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
            f" {type(image_sample)}"
        )

    if isinstance(image, list):
        image_count = len(image)
    elif isinstance(image, torch.Tensor):
        image_count = image.shape[0]
    elif isinstance(image, PIL.Image.Image):
        image_count = 1
    elif isinstance(image, np.ndarray):
        image_count = image.shape[0]
    else:
        # Not reachable: the type check above already rejected everything else.
        assert False

    if batch_size != image_count:
        raise ValueError(f"image batch size: {image_count} must be same as prompt batch size {batch_size}")

    # --- mask_image: same type rules, but a single mask may serve the whole batch ---
    mask_sample = mask_image[0] if isinstance(mask_image, list) else mask_image

    if not (
        isinstance(mask_sample, torch.Tensor)
        or isinstance(mask_sample, PIL.Image.Image)
        or isinstance(mask_sample, np.ndarray)
    ):
        raise ValueError(
            "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
            f" {type(mask_sample)}"
        )

    if isinstance(mask_image, list):
        mask_count = len(mask_image)
    elif isinstance(mask_image, torch.Tensor):
        mask_count = mask_image.shape[0]
    elif isinstance(mask_image, PIL.Image.Image):
        mask_count = 1
    elif isinstance(mask_image, np.ndarray):
        mask_count = mask_image.shape[0]
    else:
        # Not reachable: the type check above already rejected everything else.
        assert False

    if mask_count != 1 and batch_size != mask_count:
        raise ValueError(
            f"mask_image batch size: {mask_count} must be `1` or the same as prompt batch size {batch_size}"
        )
caption = re.sub(r"@[\w\d]+\b", "", caption) # 31C0—31EF CJK Strokes # 31F0—31FF Katakana Phonetic Extensions # 3200—32FF Enclosed CJK Letters and Months # 3300—33FF CJK Compatibility # 3400—4DBF CJK Unified Ideographs Extension A # 4DC0—4DFF Yijing Hexagram Symbols # 4E00—9FFF CJK Unified Ideographs caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) caption = re.sub(r"[\u3200-\u32ff]+", "", caption) caption = re.sub(r"[\u3300-\u33ff]+", "", caption) caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa "-", caption, ) # кавычки к одному стандарту caption = re.sub(r"[`´«»“”¨]", '"', caption) caption = re.sub(r"[‘’]", "'", caption) # " caption = re.sub(r""?", "", caption) # & caption = re.sub(r"&", "", caption) # ip adresses: caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) # article ids: caption = re.sub(r"\d:\d\d\s+$", "", caption) # \n caption = re.sub(r"\\n", " ", caption) # "#123" caption = re.sub(r"#\d{1,3}\b", "", caption) # "#12345.." caption = re.sub(r"#\d{5,}\b", "", caption) # "123456.." caption = re.sub(r"\b\d{6,}\b", "", caption) # filenames: caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) # caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r"\s+\.\s+", r" ", caption) # " . 
" # this-is-my-cute-cat / this_is_my_cute_cat regex2 = re.compile(r"(?:\-|\_)") if len(re.findall(regex2, caption)) > 3: caption = re.sub(regex2, " ", caption) caption = ftfy.fix_text(caption) caption = html.unescape(html.unescape(caption)) caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) caption = re.sub(r"\bpage\s+\d+\b", "", caption) caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) caption = re.sub(r"\b\s+\:\s+", r": ", caption) caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) caption = re.sub(r"\s+", " ", caption) caption.strip() caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) return caption.strip() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image def preprocess_image(self, image: PIL.Image.Image) -> torch.Tensor: if not isinstance(image, list): image = [image] def numpy_to_pt(images): if images.ndim == 3: images = images[..., None] images = torch.from_numpy(images.transpose(0, 3, 1, 2)) return images if isinstance(image[0], PIL.Image.Image): new_image = [] for image_ in image: image_ = image_.convert("RGB") image_ = resize(image_, self.unet.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 new_image.append(image_) image = new_image image = np.stack(image, axis=0) # to 
def preprocess_mask_image(self, mask_image) -> torch.Tensor:
    """Convert a mask (or list of masks) into a binarised tensor.

    The handling depends on the type of the first mask:

    * ``torch.Tensor``: 4-D tensors are concatenated, others stacked, along
      dim 0; 2-D and 3-D results get singleton dims added so the result is
      4-D.
    * ``PIL.Image.Image``: each mask is converted to mode ``"L"``, resized
      with ``resize`` to ``self.unet.sample_size``, and scaled by ``1/255``.
    * ``np.ndarray``: each mask gets two leading singleton dims and the
      results are concatenated along axis 0.

    In all three cases values below ``0.5`` become ``0`` and all others
    ``1``. Any other input type is returned unchanged, wrapped in a list if
    it was not one already.
    """
    masks = mask_image if isinstance(mask_image, list) else [mask_image]
    first = masks[0]

    if isinstance(first, torch.Tensor):
        combine = torch.cat if first.ndim == 4 else torch.stack
        batch = combine(masks, axis=0)

        if batch.ndim == 2:
            # A lone mask: add both batch and channel dims.
            batch = batch.unsqueeze(0).unsqueeze(0)
        elif batch.ndim == 3:
            if batch.shape[0] == 1:
                # Leading dim of size 1 is taken as the batch dim already present.
                batch = batch.unsqueeze(0)
            else:
                # Leading dim is the batch dim; insert the channel dim after it.
                batch = batch.unsqueeze(1)

        batch[batch < 0.5] = 0
        batch[batch >= 0.5] = 1
        return batch

    if isinstance(first, PIL.Image.Image):
        planes = [
            np.array(resize(pil_mask.convert("L"), self.unet.sample_size))[None, None, :]
            for pil_mask in masks
        ]
        stacked = np.concatenate(planes, axis=0).astype(np.float32) / 255.0
        stacked[stacked < 0.5] = 0
        stacked[stacked >= 0.5] = 1
        return torch.from_numpy(stacked)

    if isinstance(first, np.ndarray):
        stacked = np.concatenate([array_mask[None, None, :] for array_mask in masks], axis=0)
        stacked[stacked < 0.5] = 0
        stacked[stacked >= 0.5] = 1
        return torch.from_numpy(stacked)

    return masks
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of the scheduler's timesteps according to `strength`.

    The number of steps to run is `int(num_inference_steps * strength)`,
    capped at `num_inference_steps`. The schedule in
    `self.scheduler.timesteps` is sliced so that only that many trailing
    entries remain.

    Returns:
        A tuple `(timesteps, steps)` with the sliced schedule and the number
        of inference steps left after skipping the leading ones.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_index = max(num_inference_steps - steps_to_run, 0)

    remaining_schedule = self.scheduler.timesteps[first_index:]
    remaining_steps = num_inference_steps - first_index
    return remaining_schedule, remaining_steps
negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. mask_image (`PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. 
guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] self.check_inputs( prompt, image, mask_image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. 
Define call parameters device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) dtype = prompt_embeds.dtype # 4. Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) # 5. Prepare intermediate images image = self.preprocess_image(image) image = image.to(device=device, dtype=dtype) mask_image = self.preprocess_mask_image(mask_image) mask_image = mask_image.to(device=device, dtype=dtype) if mask_image.shape[0] == 1: mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) else: mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) noise_timestep = timesteps[0:1] noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) intermediate_images = self.prepare_intermediate_images( image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = ( torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images ) model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 prev_intermediate_images = intermediate_images intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = 
intermediate_images if output_type == "pil": # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL image = self.numpy_to_pil(image) # 11. Apply watermark if self.watermarker is not None: self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() else: # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, nsfw_detected, watermark_detected) return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) ================================================ FILE: 4DoF/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, is_bs4_available, is_ftfy_available, logging, randn_tensor, replace_example_docstring, ) from 
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so its shorter side equals `img_size`.

    The longer side is scaled by the width/height ratio and rounded to a
    multiple of 8. Resampling uses `PIL_INTERPOLATION["bicubic"]` with
    `reducing_gap=None`.
    """
    src_w, src_h = images.size
    aspect = src_w / src_h

    if aspect >= 1:
        # Landscape or square: height is fixed, width follows the ratio.
        target = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is fixed, height follows the inverse ratio.
        target = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate their combination.

    All component arguments are passed to `register_modules`, and
    `requires_safety_checker` is stored via `register_to_config`.

    A warning is logged when the safety checker is disabled while
    `requires_safety_checker` is true, and when `unet.config.in_channels`
    is not 6.

    Raises:
        ValueError: if a `safety_checker` is given without a
            `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message needs the f-prefix so `{self.__class__}` is interpolated
        # instead of being printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # `logger.warning` replaces the deprecated `logger.warn`; the f-prefix
        # makes the checkpoint name and channel count appear in the message.
        logger.warning(
            f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.text_encoder, self.unet, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. 
self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks def remove_all_hooks(self): if is_accelerate_available(): from accelerate.hooks import remove_hook_from_module else: raise ImportError("Please install accelerate via `pip install accelerate`") for model in [self.text_encoder, self.unet, self.safety_checker]: if model is not None: remove_hook_from_module(model, recurse=True) self.unet_offload_hook = None self.text_encoder_offload_hook = None self.final_offload_hook = None # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") clean_caption = False if clean_caption and not is_ftfy_available(): logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") clean_caption = False if not isinstance(text, (tuple, list)): text = [text] def process(text: str): if clean_caption: text = self._clean_caption(text) text = self._clean_caption(text) else: text = text.lower().strip() return text return [process(t) for t in text] # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption def _clean_caption(self, caption): caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() caption = re.sub("", "person", caption) # 
urls: caption = re.sub( r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls caption = re.sub( r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa "", caption, ) # regex for urls # html: caption = BeautifulSoup(caption, features="html.parser").text # @ caption = re.sub(r"@[\w\d]+\b", "", caption) # 31C0—31EF CJK Strokes # 31F0—31FF Katakana Phonetic Extensions # 3200—32FF Enclosed CJK Letters and Months # 3300—33FF CJK Compatibility # 3400—4DBF CJK Unified Ideographs Extension A # 4DC0—4DFF Yijing Hexagram Symbols # 4E00—9FFF CJK Unified Ideographs caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) caption = re.sub(r"[\u3200-\u32ff]+", "", caption) caption = re.sub(r"[\u3300-\u33ff]+", "", caption) caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa "-", caption, ) # кавычки к одному стандарту caption = re.sub(r"[`´«»“”¨]", '"', caption) caption = re.sub(r"[‘’]", "'", caption) # " caption = re.sub(r""?", "", caption) # & caption = re.sub(r"&", "", caption) # ip adresses: caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) # article ids: caption = re.sub(r"\d:\d\d\s+$", "", caption) # \n caption = re.sub(r"\\n", " ", caption) # "#123" caption = re.sub(r"#\d{1,3}\b", "", caption) # "#12345.." caption = re.sub(r"#\d{5,}\b", "", caption) # "123456.." 
caption = re.sub(r"\b\d{6,}\b", "", caption) # filenames: caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) # caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " # this-is-my-cute-cat / this_is_my_cute_cat regex2 = re.compile(r"(?:\-|\_)") if len(re.findall(regex2, caption)) > 3: caption = re.sub(regex2, " ", caption) caption = ftfy.fix_text(caption) caption = html.unescape(html.unescape(caption)) caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) caption = re.sub(r"\bpage\s+\d+\b", "", caption) caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) caption = re.sub(r"\b\s+\:\s+", r": ", caption) caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) caption = re.sub(r"\s+", " ", caption) caption.strip() caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) caption = re.sub(r"^\.\S+$", "", caption) return caption.strip() @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. 
After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over `image`.

    Args:
        image: Images handed to `self.numpy_to_pil` and to the safety checker.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted `pixel_values` are cast to.

    Returns:
        A tuple `(image, nsfw_detected, watermark_detected)`. When no safety
        checker is configured, `image` is returned as-is and both flags are
        `None`.
    """
    checker = self.safety_checker

    if checker is None:
        # Without a safety checker no later model runs, so the UNet has to be
        # offloaded by hand when an offload hook was installed.
        unet_hook = getattr(self, "unet_offload_hook", None)
        if unet_hook is not None:
            unet_hook.offload()
        return image, None, None

    pil_images = self.numpy_to_pil(image)
    features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = features.pixel_values.to(dtype=dtype)
    image, nsfw_detected, watermark_detected = checker(images=image, clip_input=clip_input)
    return image, nsfw_detected, watermark_detected
def check_inputs(
    self,
    prompt,
    image,
    original_image,
    mask_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of `__call__` before any computation happens.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        image: Tensor, PIL image, numpy array, or a list of those.
        original_image: Same accepted types as `image`.
        mask_image: Same accepted types as `image`; its batch size may also be 1.
        batch_size: Prompt batch size every image argument is compared against.
        callback_steps: Must be a positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings.
        negative_prompt_embeds: Pre-computed negative embeddings; must have the
            same shape as `prompt_embeds` when both are given.

    Raises:
        ValueError: If any argument has an unsupported type, an image argument
            is an empty list, or a batch size does not match.
    """

    def _image_batch_size(name, value):
        """Type-check one image argument and return the number of images it holds."""
        if isinstance(value, list):
            if not value:
                # Previously this surfaced as a bare IndexError from `value[0]`.
                raise ValueError(f"`{name}` must not be an empty list.")
            sample = value[0]
        else:
            sample = value

        if not (
            isinstance(sample, torch.Tensor)
            or isinstance(sample, PIL.Image.Image)
            or isinstance(sample, np.ndarray)
        ):
            raise ValueError(
                f"`{name}` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(sample)}"
            )

        if isinstance(value, list):
            return len(value)
        if isinstance(value, torch.Tensor):
            return value.shape[0]
        if isinstance(value, PIL.Image.Image):
            return 1
        # Only np.ndarray is left after the type check above.
        return value.shape[0]

    # `isinstance` already rejects `None`, so no separate None test is needed.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # image
    image_batch_size = _image_batch_size("image", image)
    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")

    # original_image
    image_batch_size = _image_batch_size("original_image", original_image)
    if batch_size != image_batch_size:
        raise ValueError(
            f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
        )

    # mask_image (a single mask is allowed; `__call__` repeats it across the batch)
    image_batch_size = _image_batch_size("mask_image", mask_image)
    if image_batch_size != 1 and batch_size != image_batch_size:
        raise ValueError(
            f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}"
        )
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image
def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
    """Convert `image` to a tensor and repeat it once per requested output.

    Accepts a single PIL image or numpy array, a list of PIL images, numpy
    arrays or tensors, or a tensor (which is used as given).

    Args:
        image: The image input to convert.
        num_images_per_prompt: Number of consecutive copies of every entry
            along dim 0.
        device: Device the resulting tensor is moved to.

    Returns:
        The tensor on `device`, cast to `self.unet.dtype`.

    Raises:
        ValueError: If a list of tensors holds tensors that are neither 3- nor
            4-dimensional.
    """
    if not (isinstance(image, torch.Tensor) or isinstance(image, list)):
        image = [image]

    first = image[0]

    if isinstance(first, PIL.Image.Image):
        # 0..255 pixel values are rescaled to -1..1 before batching.
        scaled = [np.array(pil).astype(np.float32) / 127.5 - 1.0 for pil in image]
        batch = np.stack(scaled, axis=0)
        image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
    elif isinstance(first, np.ndarray):
        batch = np.stack(image, axis=0)
        # Stacking already-batched arrays adds a fifth axis; keep the first entry only.
        batch = batch[0] if batch.ndim == 5 else batch
        image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
    elif isinstance(image, list) and isinstance(first, torch.Tensor):
        rank = first.ndim
        if rank not in (3, 4):
            raise ValueError(f"Image must have 3 or 4 dimensions, instead got {rank}")
        combine = torch.stack if rank == 3 else torch.concat
        image = combine(image, dim=0)

    moved = image.to(device=device, dtype=self.unet.dtype)
    return moved.repeat_interleave(num_images_per_prompt, dim=0)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of the scheduler's timesteps that `strength` asks for.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; the resulting step count is
            capped at `num_inference_steps`.

    Returns:
        A tuple `(timesteps, steps)` where `timesteps` is
        `self.scheduler.timesteps` with the leading skipped entries removed and
        `steps` is `num_inference_steps` minus the number of skipped entries.
    """
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    skipped = num_inference_steps - steps_to_run
    if skipped < 0:
        skipped = 0

    remaining = self.scheduler.timesteps[skipped:]
    return remaining, num_inference_steps - skipped
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
    original_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    mask_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.8,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: int = 100,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 0,
    clean_caption: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        image (`torch.FloatTensor`, `np.ndarray` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch. It is bilinearly resized to the height and width of
            `original_image`, noised with `image_noising_scheduler` at `noise_level`, and concatenated to the
            UNet input along the channel dimension at every denoising step.
        original_image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image` or a list of those):
            The original image that `image` was varied from. It is noised to form the initial intermediate
            images; its height and width define the output resolution.
        mask_image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image` or a list of those):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. A mask with batch size 1 is repeated for every generated
            image. NOTE(review): tensor masks appear to be expected with a single channel in dim 1 (they are
            multiplied element-wise with `(batch, channels, height, width)` images) -- confirm before relying on
            another layout.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns PIL images (safety checker and watermarker
            applied), `"pt"` returns the raw tensor without running the safety checker or watermarker, and any
            other value returns a numpy array after the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        noise_level (`int`, *optional*, defaults to 0):
            The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is a list with the generated images, and the second and third
        elements are the `nsfw_detected` and `watermark_detected` results of the `safety_checker` (both `None`
        when the safety checker did not run).
    """
    # 1. Check inputs. Raise error if not correct

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    self.check_inputs(
        prompt,
        image,
        original_image,
        mask_image,
        batch_size,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    device = self._execution_device

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    # With guidance, the unconditional and conditional embeddings share one batch
    # (unconditional half first) so a single UNet call serves both.
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # `strength` shortens the schedule: only its tail is actually run.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. prepare original image
    original_image = self.preprocess_original_image(original_image)
    original_image = original_image.to(device=device, dtype=dtype)

    # 6. prepare mask image
    mask_image = self.preprocess_mask_image(mask_image)
    mask_image = mask_image.to(device=device, dtype=dtype)

    # A single mask is shared by every generated image; a batched mask is only
    # repeated per prompt.
    if mask_image.shape[0] == 1:
        mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
    else:
        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)

    # 7. Prepare intermediate images
    # The first remaining timestep is the noise level applied to `original_image`.
    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        original_image,
        noise_timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        mask_image,
        generator,
    )

    # 8. Prepare upscaled image and noise level
    _, _, height, width = original_image.shape

    image = self.preprocess_image(image, num_images_per_prompt, device)

    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    # One noise-level entry per upscaled image; also passed to the UNet as `class_labels`.
    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    # Match the doubled batch used for classifier-free guidance.
    if do_classifier_free_guidance:
        noise_level = torch.cat([noise_level] * 2)

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # The UNet sees the current sample and the noised upscaled image stacked on the channel axis.
            model_input = torch.cat([intermediate_images, upscaled], dim=1)
            model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Each half is split on the channel axis; guidance is applied to the first part
                # and the text branch's second part (`predicted_variance`) is re-attached.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # Schedulers without a learned variance only receive the first part of the prediction.
            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            prev_intermediate_images = intermediate_images

            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # Where the mask is 1 take the new sample; where it is 0 keep the previous one.
            intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 11. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 12. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 13. Convert to PIL
        image = self.numpy_to_pil(image)

        # 14. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        # Raw tensor output: the safety checker is not run, so its flags are unset.
        nsfw_detected = None
        watermark_detected = None

        # The safety checker is skipped here, so the UNet is offloaded by hand.
        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 11. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 12. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
import IFPipelineOutput from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline >>> from diffusers.utils import pt_to_pil >>> import torch >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) >>> pipe.enable_model_cpu_offload() >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds ... 
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and sanity-check their combination.

    Args:
        tokenizer: Tokenizer registered under `tokenizer`.
        text_encoder: Text encoder registered under `text_encoder`.
        unet: Denoising UNet; a warning is logged unless it takes 6 input channels.
        scheduler: Scheduler registered under `scheduler`.
        image_noising_scheduler: Scheduler registered under `image_noising_scheduler`.
        safety_checker: Optional safety checker; requires `feature_extractor`.
        feature_extractor: Optional feature extractor used with the safety checker.
        watermarker: Optional watermarker.
        requires_safety_checker: When `True`, passing `safety_checker=None`
            logs a warning. The value is stored in the pipeline config.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The f-prefix was missing, so `{self.__class__}` was printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # `Logger.warn` is a deprecated alias; the message also lacked its f-prefix.
        logger.warning(
            "It seems like you have loaded a checkpoint that shall not be used for super resolution from"
            f" {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6."
            " Please make sure to pass a super resolution checkpoint as the `'unet'`:"
            " IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Normalize a raw caption before it is tokenized.

    The caption is URL-unquoted, lower-cased and then passed through a fixed
    sequence of substitutions that drop URLs, HTML markup, @-mentions, CJK
    ranges, ids, file names and marketing boilerplate, and that normalize
    dashes, quotes and whitespace.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The literal "<person>" tag becomes the plain word. (An empty pattern here
    # would insert "person" between every character.)
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # HTML entity for a double quote (with or without the trailing ";")
    caption = re.sub(r"&quot;?", "", caption)
    # HTML entity prefix for an ampersand
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # runs of quotes / dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the original had a bare `caption.strip()` here whose result
    # was discarded, so the anchored patterns below see the unstripped text.
    # The dead statement is dropped; behaviour is deliberately left unchanged.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get 
unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over a batch of generated images.

    Args:
        image: Images passed to `self.numpy_to_pil` and to the checker.
        device: Device the feature-extractor output is moved to.
        dtype: dtype the checker's `clip_input` is cast to.

    Returns:
        `(image, nsfw_detected, watermark_detected)`. Both flags are `None`
        when no safety checker is configured; in that case the UNet offload
        hook, if one is set, is triggered instead.
    """
    if self.safety_checker is None:
        # No checker runs, so offload the UNet manually when a hook is set.
        offload_hook = getattr(self, "unet_offload_hook", None)
        if offload_hook is not None:
            offload_hook.offload()
        return image, None, None

    checker_features = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    image, nsfw_detected, watermark_detected = self.safety_checker(
        images=image,
        clip_input=checker_features.pixel_values.to(dtype=dtype),
    )
    return image, nsfw_detected, watermark_detected
def check_inputs(
    self,
    prompt,
    image,
    batch_size,
    noise_level,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the call arguments before any model is run.

    Args:
        prompt: `str`, `list`, or `None` (then `prompt_embeds` is required).
        image: A `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or a
            non-empty list of one of those.
        batch_size: Prompt batch size the image batch must match.
        noise_level: Must lie in
            `[0, self.image_noising_scheduler.config.num_train_timesteps)`.
        callback_steps: Positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Mutually exclusive with `prompt`.
        negative_prompt_embeds: Must have the same shape as `prompt_embeds`
            when both are given.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    num_train_timesteps = self.image_noising_scheduler.config.num_train_timesteps
    if noise_level < 0 or noise_level >= num_train_timesteps:
        # The message names the attribute that is actually consulted.
        raise ValueError(
            f"`noise_level`: {noise_level} must be a valid timestep in `self.image_noising_scheduler`,"
            f" [0, {num_train_timesteps})"
        )

    if isinstance(image, list):
        if not image:
            # Without this guard `image[0]` below fails with a bare IndexError.
            raise ValueError("`image` must not be an empty list.")
        check_image_type = image[0]
    else:
        check_image_type = image

    if (
        not isinstance(check_image_type, torch.Tensor)
        and not isinstance(check_image_type, PIL.Image.Image)
        and not isinstance(check_image_type, np.ndarray)
    ):
        raise ValueError(
            "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
            f" {type(check_image_type)}"
        )

    if isinstance(image, list):
        image_batch_size = len(image)
    elif isinstance(image, torch.Tensor):
        image_batch_size = image.shape[0]
    elif isinstance(image, PIL.Image.Image):
        image_batch_size = 1
    elif isinstance(image, np.ndarray):
        image_batch_size = image.shape[0]
    else:
        # Not reachable after the type check above; raise rather than `assert`,
        # which is stripped under `python -O`.
        raise ValueError(f"Unsupported `image` type: {type(image)}")

    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
def preprocess_image(self, image, num_images_per_prompt, device):
    """Convert the conditioning image(s) into one batched tensor for the UNet.

    Accepts a tensor, a PIL image, a numpy array, or a list of any of these.
    PIL images are rescaled with `x / 127.5 - 1.0`; PIL and numpy inputs are
    transposed with `(0, 3, 1, 2)`. A list of tensors is stacked (3-D items)
    or concatenated (4-D items). The result is moved to `device`, cast to
    `self.unet.dtype` and repeated `num_images_per_prompt` times along dim 0.

    Raises:
        ValueError: If a list of tensors holds items that are neither 3-D nor 4-D.
    """
    if not isinstance(image, (torch.Tensor, list)):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        scaled = [np.array(frame).astype(np.float32) / 127.5 - 1.0 for frame in image]
        batch = np.stack(scaled, axis=0)
        image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
    elif isinstance(first, np.ndarray):
        batch = np.stack(image, axis=0)
        # Stacking a single already-batched array adds an extra leading axis.
        if batch.ndim == 5:
            batch = batch[0]
        image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
    elif isinstance(image, list) and isinstance(first, torch.Tensor):
        dims = first.ndim
        if dims == 3:
            image = torch.stack(image, dim=0)
        elif dims == 4:
            image = torch.concat(image, dim=0)
        else:
            raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")

    image = image.to(device=device, dtype=self.unet.dtype)
    return image.repeat_interleave(num_images_per_prompt, dim=0)
cross_attention_kwargs: Optional[Dict[str, Any]] = None, noise_level: int = 250, clean_caption: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): The image to be upscaled. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to 250): The amount of noise to add to the upscaled image. 
Must be in the range `[0, 1000)` clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] self.check_inputs( prompt, image, batch_size, noise_level, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. Define call parameters height = height or self.unet.config.sample_size width = width or self.unet.config.sample_size device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare intermediate images num_channels = self.unet.config.in_channels // 2 intermediate_images = self.prepare_intermediate_images( batch_size * num_images_per_prompt, num_channels, height, width, prompt_embeds.dtype, device, generator, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare upscaled image and noise level image = self.preprocess_image(image, num_images_per_prompt, device) upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) if do_classifier_free_guidance: noise_level = torch.cat([noise_level] * 2) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = torch.cat([intermediate_images, upscaled], dim=1) model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = intermediate_images if output_type == "pil": # 9. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 10. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 11. Convert to PIL image = self.numpy_to_pil(image) # 12. 
class IFSafetyChecker(PreTrainedModel):
    """CLIP-vision based checker with two linear heads: NSFW (`p_head`) and watermark (`w_head`).

    Images flagged by either head are replaced in place by all-zero arrays.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        vision_config = config.vision_config
        self.vision_model = CLIPVisionModelWithProjection(vision_config)

        # One scalar score per image for each detector.
        self.p_head = nn.Linear(vision_config.projection_dim, 1)
        self.w_head = nn.Linear(vision_config.projection_dim, 1)

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """Score `clip_input` and zero out every flagged entry of `images`.

        Args:
            clip_input: Input forwarded to the CLIP vision model.
            images: Indexable collection of arrays; flagged entries are
                overwritten with `np.zeros` of the same shape.
            p_threshold: NSFW score above which an image is flagged.
            w_threshold: Watermark score above which an image is flagged.

        Returns:
            `(images, nsfw_detected, watermark_detected)`, the last two being
            lists of `bool` with one entry per image.
        """
        image_embeds = self.vision_model(clip_input)[0]

        nsfw_detected = (self.p_head(image_embeds).flatten() > p_threshold).tolist()
        if any(nsfw_detected):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )
        for idx, flagged in enumerate(nsfw_detected):
            if flagged:
                images[idx] = np.zeros(images[idx].shape)

        watermark_detected = (self.w_head(image_embeds).flatten() > w_threshold).tolist()
        if any(watermark_detected):
            logger.warning(
                "Potential watermarked content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )
        for idx, flagged in enumerate(watermark_detected):
            if flagged:
                images[idx] = np.zeros(images[idx].shape)

        return images, nsfw_detected, watermark_detected
307, 264, 263, 220, 219, 176, 132, 88, 44, 0, ] smart185_timesteps = [ 999, 997, 995, 992, 990, 988, 986, 984, 981, 979, 977, 975, 972, 970, 968, 966, 964, 961, 959, 957, 956, 954, 951, 949, 946, 944, 941, 939, 936, 934, 931, 929, 926, 924, 921, 919, 916, 914, 913, 910, 907, 905, 902, 899, 896, 893, 891, 888, 885, 882, 879, 877, 874, 871, 870, 867, 864, 861, 858, 855, 852, 849, 846, 843, 840, 837, 834, 831, 828, 827, 824, 821, 817, 814, 811, 808, 804, 801, 798, 795, 791, 788, 785, 784, 780, 777, 774, 770, 766, 763, 760, 756, 752, 749, 746, 742, 741, 737, 733, 730, 726, 722, 718, 714, 710, 707, 703, 699, 698, 694, 690, 685, 681, 677, 673, 669, 664, 660, 656, 655, 650, 646, 641, 636, 632, 627, 622, 618, 613, 612, 607, 602, 596, 591, 586, 580, 575, 570, 569, 563, 557, 551, 545, 539, 533, 527, 526, 519, 512, 505, 498, 491, 484, 483, 474, 466, 457, 449, 440, 439, 428, 418, 407, 396, 395, 381, 366, 352, 351, 330, 308, 307, 286, 264, 263, 242, 220, 219, 176, 175, 132, 131, 88, 44, 0, ] super27_timesteps = [ 999, 991, 982, 974, 966, 958, 950, 941, 933, 925, 916, 908, 900, 899, 874, 850, 825, 800, 799, 700, 600, 500, 400, 300, 200, 100, 0, ] super40_timesteps = [ 999, 992, 985, 978, 971, 964, 957, 949, 942, 935, 928, 921, 914, 907, 900, 899, 879, 859, 840, 820, 800, 799, 766, 733, 700, 699, 650, 600, 599, 500, 499, 400, 399, 300, 299, 200, 199, 100, 99, 0, ] super100_timesteps = [ 999, 996, 992, 989, 985, 982, 979, 975, 972, 968, 965, 961, 958, 955, 951, 948, 944, 941, 938, 934, 931, 927, 924, 920, 917, 914, 910, 907, 903, 900, 899, 891, 884, 876, 869, 861, 853, 846, 838, 830, 823, 815, 808, 800, 799, 788, 777, 766, 755, 744, 733, 722, 711, 700, 699, 688, 677, 666, 655, 644, 633, 622, 611, 600, 599, 585, 571, 557, 542, 528, 514, 500, 499, 485, 471, 457, 442, 428, 414, 400, 399, 379, 359, 340, 320, 300, 299, 279, 259, 240, 220, 200, 199, 166, 133, 100, 99, 66, 33, 0, ] ================================================ FILE: 4DoF/diffusers/pipelines/deepfloyd_if/watermark.py 
class IFWatermarker(ModelMixin, ConfigMixin):
    """Pastes an RGBA watermark near the bottom-right corner of generated PIL images."""

    def __init__(self):
        super().__init__()

        # 62x62 RGBA bitmap registered as a buffer; initialised to zeros here
        # (presumably overwritten when pretrained weights are loaded - TODO confirm).
        self.register_buffer("watermark_image", torch.zeros((62, 62, 4)))
        # PIL version of the buffer, built lazily on first use.
        self.watermark_image_as_pil = None

    def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None):
        """Paste the watermark onto every image in `images`, in place.

        Args:
            images: Non-empty list of PIL images; the size of the first image
                determines the watermark size and position for all of them.
            sample_size: Reference size used to derive the scale; defaults to
                the height of the first image.

        Returns:
            The same `images` list, with each image modified in place.
        """
        # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287

        first = images[0]
        h, w = first.height, first.width

        sample_size = sample_size or h

        coef = min(h / sample_size, w / sample_size)
        if coef < 1:
            img_h, img_w = int(h / coef), int(w / coef)
        else:
            img_h, img_w = h, w

        # Scale relative to a 1024x1024 reference area.
        S1 = 1024**2
        S2 = img_w * img_h
        K = (S2 / S1) ** 0.5

        wm_size = int(K * 62)
        margin = int(14 * K)
        wm_x = img_w - margin
        wm_y = img_h - margin

        if self.watermark_image_as_pil is None:
            rgba = self.watermark_image.to(torch.uint8).cpu().numpy()
            self.watermark_image_as_pil = Image.fromarray(rgba, mode="RGBA")

        wm_img = self.watermark_image_as_pil.resize(
            (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None
        )

        box = (wm_x - wm_size, wm_y - wm_size, wm_x, wm_y)
        for pil_img in images:
            # The alpha band (last channel) is used as the paste mask.
            pil_img.paste(wm_img, box=box, mask=wm_img.split()[-1])

        return images
class DiTPipeline(DiffusionPipeline):
    r"""
    This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        transformer ([`Transformer2DModel`]):
            Class conditioned Transformer in Diffusion model to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `dit` to denoise the encoded image latents.
        id2label (`Dict[int, str]`, *optional*):
            Mapping from class id to a comma-separated string of label names. When given, it is inverted into
            `self.labels` (label name -> class id) for use by `get_label_ids`.
    """

    def __init__(
        self,
        transformer: Transformer2DModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
        id2label: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

        # create a imagenet -> id dictionary for easier use
        self.labels = {}
        if id2label is not None:
            for key, value in id2label.items():
                # one id may carry several comma-separated synonyms; each becomes its own entry
                for label in value.split(","):
                    self.labels[label.strip()] = int(key)
            self.labels = dict(sorted(self.labels.items()))

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map label strings, *e.g.* from ImageNet, to corresponding class ids.

        Parameters:
            label (`str` or `list` of `str`):
                label strings to be mapped to class ids.

        Returns:
            `list` of `int`: Class ids to be processed by pipeline.

        Raises:
            ValueError: If any requested label is not a key of `self.labels`.
        """
        if isinstance(label, str):
            # A bare string must be wrapped, not passed to `list()`, which would split it into characters.
            label = [label]
        elif not isinstance(label, list):
            label = list(label)

        for name in label:
            if name not in self.labels:
                raise ValueError(
                    f"{name} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
                )

        return [self.labels[name] for name in label]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: List[int],
        guidance_scale: float = 4.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            class_labels (List[int]):
                List of imagenet class labels for the images to be generated.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Scale of the guidance signal. Guidance is only applied when the value is greater than 1.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`ImagePipelineOutput`] or `tuple`: the generated images, wrapped in a one-element tuple when
            `return_dict` is `False`.
        """

        use_guidance = guidance_scale > 1

        batch_size = len(class_labels)
        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels

        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            device=self.device,
            dtype=self.transformer.dtype,
        )
        # With guidance the batch is doubled: first half conditional, second half paired with the null class.
        latent_model_input = torch.cat([latents] * 2) if use_guidance else latents

        class_labels = torch.tensor(class_labels, device=self.device).reshape(-1)
        # NOTE(review): 1000 is used as the "no class" id — presumably one past the last ImageNet class; confirm.
        class_null = torch.tensor([1000] * batch_size, device=self.device)
        class_labels_input = torch.cat([class_labels, class_null], 0) if use_guidance else class_labels

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            if use_guidance:
                # both halves are driven by the same latents, so re-duplicate the first half each step
                half = latent_model_input[: len(latent_model_input) // 2]
                latent_model_input = torch.cat([half, half], dim=0)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timesteps = t
            if not torch.is_tensor(timesteps):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(timesteps, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(latent_model_input.shape[0])
            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input, timestep=timesteps, class_labels=class_labels_input
            ).sample

            # perform guidance
            if use_guidance:
                # only the first `latent_channels` channels are guided; the remaining channels pass through
                eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
                cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = torch.cat([half_eps, half_eps], dim=0)

                noise_pred = torch.cat([eps, rest], dim=1)

            # learned sigma
            if self.transformer.config.out_channels // 2 == latent_channels:
                model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
            else:
                model_output = noise_pred

            # compute previous image: x_t -> x_t-1
            latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

        if use_guidance:
            latents, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents = latent_model_input

        latents = 1 / self.vae.config.scaling_factor * latents
        samples = self.vae.decode(latents).sample

        # map the decoder output from [-1, 1] to [0, 1]
        samples = (samples / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            samples = self.numpy_to_pil(samples)

        if not return_dict:
            return (samples,)

        return ImagePipelineOutput(images=samples)
FILE: 4DoF/diffusers/pipelines/kandinsky/__init__.py ================================================ from ...utils import ( OptionalDependencyNotAvailable, is_torch_available, is_transformers_available, is_transformers_version, ) try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import KandinskyPipeline, KandinskyPriorPipeline else: from .pipeline_kandinsky import KandinskyPipeline from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput from .text_encoder import MultilingualCLIP ================================================ FILE: 4DoF/diffusers/pipelines/kandinsky/pipeline_kandinsky.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def get_new_h_w(h, w, scale_factor=8):
    """Round `h` and `w` up to whole multiples of `scale_factor**2`, then divide by `scale_factor`.

    Each dimension is converted to the number of `scale_factor**2`-sized blocks needed to cover it (rounding up), and
    that block count is multiplied by `scale_factor`.

    Returns:
        `tuple`: `(new_h, new_w)`.
    """
    block = scale_factor**2

    def _blocks_needed(size):
        whole, leftover = divmod(size, block)
        return whole + 1 if leftover != 0 else whole

    return _blocks_needed(h) * scale_factor, _blocks_needed(w) * scale_factor
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
    ):
        """
        Tokenize and encode `prompt` (and, for classifier-free guidance, the negative prompt) with the text encoder.

        Args:
            prompt (`str` or `List[str]`):
                Prompt(s) to encode. A non-list value is treated as a batch of one.
            device:
                Device the token ids and attention masks are moved to before encoding.
            num_images_per_prompt (`int`):
                Number of times each prompt's outputs are repeated along the batch dimension.
            do_classifier_free_guidance (`bool`):
                When true, unconditional outputs are computed and concatenated *in front of* the conditional ones.
            negative_prompt (`str` or `List[str]`, *optional*):
                Prompt(s) used for the unconditional branch. Empty strings are used when it is `None`.

        Returns:
            `tuple`: `(prompt_embeds, text_encoder_hidden_states, text_mask)`. With guidance enabled, each tensor
            holds the unconditional entries first, followed by the conditional entries.

        Raises:
            TypeError: If `negative_prompt` is given and its type differs from the type of `prompt`.
            ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation, only to detect (and report) what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_input_ids = text_input_ids.to(device)
        text_mask = text_inputs.attention_mask.to(device)

        # The text encoder returns a pair: a per-prompt embedding and the per-token hidden states.
        prompt_embeds, text_encoder_hidden_states = self.text_encoder(
            input_ids=text_input_ids, attention_mask=text_mask
        )

        # Repeat every prompt's outputs once per requested image (consecutive copies of the same prompt).
        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            uncond_text_input_ids = uncond_input.input_ids.to(device)
            uncond_text_mask = uncond_input.attention_mask.to(device)

            negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
                input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            # `negative_prompt_embeds` is handled as a 2-D tensor: tile along dim 1, then fold the copies into dim 0.
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            # The hidden states are handled as a 3-D tensor; same tile-then-fold trick, last dim inferred.
            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.text_encoder, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is not greater than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        # From here on `batch_size` counts generated images, not prompts.
        batch_size = batch_size * num_images_per_prompt
        do_classifier_free_guidance = guidance_scale > 1.0

        # The returned attention mask is not needed by this method.
        prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        # NOTE(review): the per-image repeat and the dtype/device move of `image_embeds` only happen when guidance is
        # enabled; with `guidance_scale <= 1` the embeddings are passed to the unet as given — confirm this is intended.
        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            # negative first, positive second — the same ordering `noise_pred.chunk(2)` relies on below
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
                dtype=prompt_embeds.dtype, device=device
            )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels

        # Convert the requested size with `get_new_h_w`; the results are used as the latent height/width.
        height, width = get_new_h_w(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The unet output is split along channels into a noise part (as many channels as the latents) and a
                # remainder treated as the predicted variance.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # only the text-conditioned half of the variance is kept
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Drop the variance channels unless the scheduler is configured to consume a learned variance.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            ).prev_sample
        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        # NOTE(review): `output_type` is validated only after the whole denoising loop and decode have run.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        if output_type in ["np", "pil"]:
            # rescale from [-1, 1] to [0, 1] and move channels last before converting to numpy
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
from typing import List, Optional, Union import numpy as np import PIL import torch from PIL import Image from transformers import ( XLMRobertaTokenizer, ) from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "A red cartoon frog, 4k" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/frog.png" ... ) >>> image = pipe( ... prompt, ... image=init_image, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... strength=0.2, ... 
def prepare_image(pil_image, w=512, h=512):
    """Convert a PIL image into a `(1, 3, h, w)` float32 tensor scaled to `[-1, 1]`.

    The image is resized to `(w, h)` with bicubic resampling, converted to RGB, mapped from `[0, 255]` to `[-1, 1]`
    via `x / 127.5 - 1`, moved to channels-first layout and given a leading batch dimension.
    """
    resized = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    pixels = np.array(resized.convert("RGB")).astype(np.float32)
    channels_first = np.transpose(pixels / 127.5 - 1, [2, 0, 1])
    return torch.from_numpy(channels_first).unsqueeze(0)
movq ([`VQModel`]): MoVQ image encoder and decoder """ def __init__( self, text_encoder: MultilingualCLIP, movq: VQModel, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, scheduler: DDIMScheduler, ): super().__init__() self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, movq=movq, ) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma shape = latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.add_noise(latents, noise, latent_timestep) return latents def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = 
self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) text_mask = text_inputs.attention_mask.to(device) prompt_embeds, text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.text_encoder, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # add_noise method to overwrite the one in schedule because it use a different beta schedule for adding noise vs sampling def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], image_embeds: torch.FloatTensor, negative_image_embeds: torch.FloatTensor, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, strength: float = 0.3, guidance_scale: float = 7.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = 
None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`torch.FloatTensor`, `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. strength (`float`, *optional*, defaults to 0.3): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. 
of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ # 1. Define call parameters if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 # 2. 
get text and image embeddings prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=prompt_embeds.dtype, device=device ) # 3. pre-processing initial image if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) image = image.to(dtype=prompt_embeds.dtype, device=device) latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # the formular to calculate timestep for add_noise is taken from the original kandinsky repo latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) num_channels_latents = self.unet.config.in_channels height, width = get_new_h_w(height, width, self.movq_scale_factor) # 5. 
Create initial latent latents = self.prepare_latents( latents, latent_timestep, (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, self.scheduler, ) # 6. Denoising loop for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample # 7. post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy from typing import List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from PIL import Image from transformers import ( XLMRobertaTokenizer, ) from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> import numpy as np >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "a hat" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyInpaintPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> mask = np.ones((768, 768), dtype=np.float32) >>> mask[:250, 250:-250] = 0 >>> out = pipe( ... prompt, ... image=init_image, ... mask_image=mask, ... image_embeds=image_emb, ... 
def prepare_mask(masks):
    """
    Grow the non-1 regions of every mask by one pixel along six fixed directions.

    For each pixel whose value in channel 0 is not 1, the neighbours at offsets (-1, 0), (0, -1), (-1, -1),
    (+1, 0), (0, +1) and (+1, +1) that lie inside the mask are set to 0 in all channels. The pixel itself is
    left untouched, and the decision is based on the values before any modification, so the growth does not
    cascade.

    The six offsets are applied to the whole mask at once with shifted boolean ORs rather than a Python loop
    over every pixel.

    Args:
        masks: Iterable of `channels x height x width` tensors, e.g. a `batch x channels x height x width`
            tensor. Each mask is modified in place.

    Returns:
        `torch.Tensor`: The processed masks stacked along a new leading dimension.
    """
    prepared_masks = []
    for mask in masks:
        # Snapshot of the pixels to grow from; the comparison allocates a new tensor, so editing `mask` below
        # cannot feed back into it.
        hole = mask[0] != 1

        spread = torch.zeros_like(hole)
        spread[:-1, :].logical_or_(hole[1:, :])  # (i - 1, j)
        spread[:, :-1].logical_or_(hole[:, 1:])  # (i, j - 1)
        spread[:-1, :-1].logical_or_(hole[1:, 1:])  # (i - 1, j - 1)
        spread[1:, :].logical_or_(hole[:-1, :])  # (i + 1, j)
        spread[:, 1:].logical_or_(hole[:, :-1])  # (i, j + 1)
        spread[1:, 1:].logical_or_(hole[:-1, :-1])  # (i + 1, j + 1)

        # The (height, width) selector broadcasts over the channel dimension.
        mask.masked_fill_(spread, 0)
        prepared_masks.append(mask)
    return torch.stack(prepared_masks, dim=0)
height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. Raises: ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (ot the other way around). Returns: tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ if image is None: raise ValueError("`image` input cannot be undefined.") if mask is None: raise ValueError("`mask_image` input cannot be undefined.") if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask if mask.ndim == 2: mask = mask.unsqueeze(0).unsqueeze(0) # Batch single mask or add channel dim if mask.ndim == 3: # Single batched mask, no channel dim or single mask not batched but channel dim if mask.shape[0] == 1: mask = mask.unsqueeze(0) # Batched masks no channel dim else: mask = mask.unsqueeze(1) assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: raise ValueError("Image should be in [-1, 1] range") # Check mask is in [0, 1] if mask.min() < 0 or mask.max() > 1: raise ValueError("Mask should be in [0, 1] range") # Binarize mask mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 
# Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 # preprocess mask if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) return mask, image class KandinskyInpaintPipeline(DiffusionPipeline): """ Pipeline for text-guided image inpainting using Kandinsky2.1 This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: text_encoder ([`MultilingualCLIP`]): Frozen text-encoder. tokenizer ([`XLMRobertaTokenizer`]): Tokenizer of class scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. 
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
    """
    Return the starting latents for the denoising loop, scaled by the scheduler's initial noise sigma.

    Args:
        shape: Expected shape of the latents.
        dtype: Dtype used when fresh noise has to be sampled.
        device: Device the latents are created on or moved to.
        generator: Random generator(s) forwarded to `randn_tensor` when sampling.
        latents: Optional pre-made latents; when `None`, noise of `shape` is sampled instead.
        scheduler: Object exposing `init_noise_sigma`.

    Returns:
        The latents multiplied by `scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `latents` is given but its shape differs from `shape`.
    """
    if latents is not None:
        # Caller-supplied latents are only validated and moved; their dtype is left untouched.
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        start = latents.to(device)
    else:
        start = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return start * scheduler.init_noise_sigma
text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.text_encoder, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image], mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], image_embeds: torch.FloatTensor, negative_image_embeds: torch.FloatTensor, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`torch.FloatTensor`, `PIL.Image.Image` or `np.ndarray`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. mask_image (`PIL.Image.Image`,`torch.FloatTensor` or `np.ndarray`): `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H, W)` If image is an PIL image or numpy array, mask should also be a either PIL image or numpy array. If it is a PIL image, it will be converted to a single channel (luminance) before use. If it is a nummpy array, the expected shape is `(H, W)`. 
image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ # Define call parameters if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=prompt_embeds.dtype, device=device ) # preprocess image and mask mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) image = image.to(dtype=prompt_embeds.dtype, device=device) image = self.movq.encode(image)["latents"] mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device) image_shape = tuple(image.shape[-2:]) mask_image = F.interpolate( mask_image, image_shape, mode="nearest", ) mask_image = prepare_mask(mask_image) masked_image = image * mask_image mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) masked_image = 
masked_image.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: mask_image = mask_image.repeat(2, 1, 1, 1) masked_image = masked_image.repeat(2, 1, 1, 1) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor = self.scheduler.timesteps num_channels_latents = self.movq.config.latent_channels # get h, w for latents sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( (batch_size, num_channels_latents, sample_height, sample_width), text_encoder_hidden_states.dtype, device, generator, latents, self.scheduler, ) # Check that sizes of mask, masked image and latents match with expected num_channels_mask = mask_image.shape[1] num_channels_masked_image = masked_image.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." 
) for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( BaseOutput, is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> negative_image_emb = out.negative_image_embeds >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") >>> pipe.to("cuda") >>> image = pipe( ... prompt, ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... 
@dataclass
class KandinskyPriorPipelineOutput(BaseOutput):
    """
    Output class for KandinskyPriorPipeline.

    Both fields hold the same kind of value; the declared types allow either a torch tensor or a numpy array.

    Args:
        image_embeds (`torch.FloatTensor` or `np.ndarray`)
            clip image embeddings for text prompt
        negative_image_embeds (`torch.FloatTensor` or `np.ndarray`)
            clip image embeddings for unconditional tokens
    """

    image_embeds: Union[torch.FloatTensor, np.ndarray]
    negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
def interpolate(
    self,
    images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
    weights: List[float],
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    negative_prior_prompt: Optional[str] = None,
    negative_prompt: Union[str] = "",
    guidance_scale: float = 4.0,
    device=None,
):
    """
    Function invoked when using the prior pipeline for interpolation.

    Every condition is converted to an image embedding (text prompts by running the prior, images by running the
    image encoder), scaled by its weight, and the scaled embeddings are summed.

    Args:
        images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
            list of prompts and images to guide the image generation. `PIL` images are preprocessed with the
            pipeline's image processor; tensors are passed to the image encoder unchanged.
        weights: (`List[float]`):
            list of weights for each condition in `images_and_prompts`
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt. The returned `image_embeds` has one row per requested
            image.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        negative_prior_prompt (`str`, *optional*):
            The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is less than `1`).
        negative_prompt (`str`, *optional*, defaults to `""`):
            Prompt that is run through the prior to obtain the returned `negative_image_embeds`. When left empty,
            the `negative_image_embeds` of a prior run on the empty prompt is returned instead.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        device (`torch.device`, *optional*):
            Device that preprocessed `PIL` images are moved to. Defaults to `self.device`.

    Examples:

    Returns:
        [`KandinskyPriorPipelineOutput`]

    Raises:
        ValueError: If `images_and_prompts` and `weights` differ in length, or a condition has an unsupported type.
    """

    device = device or self.device

    if len(images_and_prompts) != len(weights):
        raise ValueError(
            f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
        )

    weighted_embeddings = []
    for cond, weight in zip(images_and_prompts, weights):
        if isinstance(cond, str):
            # The prior returns one embedding row per requested image.
            image_emb = self(
                cond,
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=num_images_per_prompt,
                generator=generator,
                latents=latents,
                negative_prompt=negative_prior_prompt,
                guidance_scale=guidance_scale,
            ).image_embeds

        elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
            if isinstance(cond, PIL.Image.Image):
                cond = (
                    self.image_processor(cond, return_tensors="pt")
                    .pixel_values[0]
                    .unsqueeze(0)
                    .to(dtype=self.image_encoder.dtype, device=device)
                )

            # Collapse the encoder output to a single row (a no-op for one image) so it can be broadcast
            # against the per-image rows produced by text conditions.
            image_emb = self.image_encoder(cond)["image_embeds"].sum(dim=0, keepdim=True)

        else:
            raise ValueError(
                f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}"
            )

        # Give every condition the same (num_images_per_prompt, dim) shape before mixing. Summing a flat
        # concatenation instead would add the per-image text samples together and return a single row.
        weighted_embeddings.append((image_emb * weight).expand(num_images_per_prompt, -1))

    image_emb = torch.stack(weighted_embeddings).sum(dim=0)

    out_zero = self(
        negative_prompt,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        latents=latents,
        negative_prompt=negative_prior_prompt,
        guidance_scale=guidance_scale,
    )
    # An empty negative prompt means "no negative": use the negative embeddings of that run rather than the
    # embedding generated for the empty string.
    zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds

    return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
def get_zero_embed(self, batch_size=1, device=None):
    """
    Encode an all-zero image and tile the resulting embedding.

    Args:
        batch_size (`int`, *optional*, defaults to 1):
            Number of times the embedding is repeated along the first dimension.
        device (`torch.device`, *optional*):
            Device the zero image is created on. Defaults to `self.device`.

    Returns:
        The `"image_embeds"` entry of the image encoder output for the zero image, repeated `batch_size` times.
    """
    target_device = device or self.device
    encoder = self.image_encoder
    side = encoder.config.image_size

    # Single 3-channel square image filled with zeros, in the encoder's own dtype.
    blank = torch.zeros(1, 3, side, side).to(device=target_device, dtype=encoder.dtype)
    return encoder(blank)["image_embeds"].repeat(batch_size, 1)
""" if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): return self.device for module in self.text_encoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` 
should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    guidance_scale: float = 4.0,
    output_type: Optional[str] = "pt",  # "pt" or "np"
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`). When given, it is appended to `prompt` so that its embedding is
            generated in the same run and returned as `negative_image_embeds`. A single negative prompt is applied
            to every prompt; otherwise one negative prompt per prompt is required.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        output_type (`str`, *optional*, defaults to `"pt"`):
            The output format of the generated embeddings. Choose between: `"np"` (`np.array`) or `"pt"`
            (`torch.Tensor`).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`KandinskyPriorPipelineOutput`] or `tuple`

    Raises:
        ValueError: If `prompt` or `negative_prompt` has an unsupported type, if their lengths cannot be matched,
            or if `output_type` is not `"pt"` or `"np"`.
    """

    if isinstance(prompt, str):
        prompt = [prompt]
    elif not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    elif not isinstance(negative_prompt, list) and negative_prompt is not None:
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # Reject an unsupported output type before spending time on the denoising loop.
    if output_type not in ["pt", "np"]:
        raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

    # if the negative prompt is defined we double the batch size to
    # directly retrieve the negative prompt embedding
    if negative_prompt is not None:
        # The final `chunk(2)` split is only correct when both halves of the batch have the same size.
        if len(negative_prompt) == 1 and len(prompt) > 1:
            negative_prompt = negative_prompt * len(prompt)
        elif len(negative_prompt) != len(prompt):
            raise ValueError(
                f"`negative_prompt` has {len(negative_prompt)} entries but `prompt` has {len(prompt)}."
                " Pass a single negative prompt or exactly one negative prompt per prompt."
            )
        prompt = prompt + negative_prompt
        negative_prompt = 2 * negative_prompt

    device = self._execution_device

    batch_size = len(prompt)
    batch_size = batch_size * num_images_per_prompt

    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # prior
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    prior_timesteps_tensor = self.scheduler.timesteps

    embedding_dim = self.prior.config.embedding_dim

    latents = self.prepare_latents(
        (batch_size, embedding_dim),
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        self.scheduler,
    )

    for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        predicted_image_embedding = self.prior(
            latent_model_input,
            timestep=t,
            proj_embedding=prompt_embeds,
            encoder_hidden_states=text_encoder_hidden_states,
            attention_mask=text_mask,
        ).predicted_image_embedding

        if do_classifier_free_guidance:
            predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
            predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                predicted_image_embedding_text - predicted_image_embedding_uncond
            )

        # The scheduler is told the next timestep explicitly; there is none after the last step.
        if i + 1 == prior_timesteps_tensor.shape[0]:
            prev_timestep = None
        else:
            prev_timestep = prior_timesteps_tensor[i + 1]

        latents = self.scheduler.step(
            predicted_image_embedding,
            timestep=t,
            sample=latents,
            generator=generator,
            prev_timestep=prev_timestep,
        ).prev_sample

    latents = self.prior.post_process_latents(latents)

    image_embeddings = latents

    # if a negative prompt has been defined, we split the image embedding into two
    if negative_prompt is None:
        zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
    else:
        image_embeddings, zero_embeds = image_embeddings.chunk(2)

    if output_type == "np":
        image_embeddings = image_embeddings.cpu().numpy()
        zero_embeds = zero_embeds.cpu().numpy()

    if not return_dict:
        return (image_embeddings, zero_embeds)

    return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
class MCLIPConfig(XLMRobertaConfig):
    """XLM-RoBERTa configuration extended with the input and output sizes of the projection layer."""

    model_type = "M-CLIP"

    def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
        # The two extra fields are set before the remaining kwargs are handed to the base configuration.
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)


class MultilingualCLIP(PreTrainedModel):
    """XLM-RoBERTa encoder followed by a linear projection of its mask-averaged token states."""

    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        self.LinearTransformation = torch.nn.Linear(config.transformerDimensions, config.numDims)

    def forward(self, input_ids, attention_mask):
        """
        Encode `input_ids` and return `(projected, token_states)`.

        `token_states` is the first element of the transformer output. `projected` is the linear projection of
        the average of `token_states` over dim 1, weighted by `attention_mask`.
        """
        token_states = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]

        # Masked mean: zero out masked positions, sum over dim 1, divide by the per-row mask total.
        masked_sum = (token_states * attention_mask.unsqueeze(2)).sum(dim=1)
        mask_total = attention_mask.sum(dim=1)[:, None]
        pooled = masked_sum / mask_total

        return self.LinearTransformation(pooled), token_states
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Union import torch from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> zero_image_emb = out.negative_image_embeds >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... 
def downscale_height_and_width(height, width, scale_factor=8):
    """
    Convert a pixel height and width into the matching downscaled sizes.

    Each dimension is divided by `scale_factor**2`, rounded up, and multiplied by `scale_factor`.

    Args:
        height: Height in pixels.
        width: Width in pixels.
        scale_factor: Downscaling factor, defaults to 8.

    Returns:
        A `(new_height, new_width)` tuple.
    """
    block = scale_factor**2

    def _scaled(pixels):
        # Ceiling division by `block`, then back up by one factor of `scale_factor`.
        blocks, leftover = divmod(pixels, block)
        if leftover != 0:
            blocks += 1
        return blocks * scale_factor

    return _scaled(height), _scaled(width)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
    to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
    method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
    `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index of the CUDA device the models are executed on.

    Raises:
        ImportError: If `accelerate` is missing or older than `0.17.0`.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_ok:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target = torch.device(f"cuda:{gpu_id}")

    already_on_cpu = self.device.type == "cpu"
    if not already_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # Chain the hooks in execution order: each call receives the hook returned for the previous model.
    last_hook = None
    for model in (self.unet, self.movq):
        _model, last_hook = cpu_offload_with_hook(model, target, prev_module_hook=last_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = last_hook
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
    negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
    height: int = 512,
    width: int = 512,
    num_inference_steps: int = 100,
    guidance_scale: float = 4.0,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
            The clip image embeddings for text prompt, that will be used to condition the image generation. A list
            is concatenated along dimension 0.
        negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
            The clip image embeddings for negative text prompt, will be used to condition the image generation.
            Only consumed when classifier-free guidance is active (`guidance_scale > 1`).
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
            (`np.array`) or `"pt"` (`torch.Tensor`).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`

    Raises:
        ValueError: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`.
    """
    # Reject an unsupported output type before any work is done instead of after the full denoising loop.
    if output_type not in ["pt", "np", "pil"]:
        raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

    device = self._execution_device

    do_classifier_free_guidance = guidance_scale > 1.0

    if isinstance(image_embeds, list):
        image_embeds = torch.cat(image_embeds, dim=0)
    if isinstance(negative_image_embeds, list):
        negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

    batch_size = image_embeds.shape[0] * num_images_per_prompt

    # The conditioning must match the latent batch (`batch_size`) and the UNet dtype/device whether or not
    # guidance is used; previously this only happened on the guided path.
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    if do_classifier_free_guidance:
        negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # Negative embeddings first, so that `chunk(2)` below yields (unconditional, conditional).
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
    image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)

    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps_tensor = self.scheduler.timesteps

    num_channels_latents = self.unet.config.in_channels

    height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

    # create initial latent
    latents = self.prepare_latents(
        (batch_size, num_channels_latents, height, width),
        image_embeds.dtype,
        device,
        generator,
        latents,
        self.scheduler,
    )

    # The conditioning does not change between steps, so build it once.
    added_cond_kwargs = {"image_embeds": image_embeds}

    for t in self.progress_bar(timesteps_tensor):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        noise_pred = self.unet(
            sample=latent_model_input,
            timestep=t,
            encoder_hidden_states=None,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if do_classifier_free_guidance:
            # Split the UNet output channel-wise into the noise part and the variance part, apply guidance
            # to the noise part only, then re-attach the variance of the conditional half.
            noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            _, variance_pred_text = variance_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

        # Schedulers without a learned variance only receive the noise channels.
        if not (
            hasattr(self.scheduler.config, "variance_type")
            and self.scheduler.config.variance_type in ["learned", "learned_range"]
        ):
            noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(
            noise_pred,
            t,
            latents,
            generator=generator,
        )[0]

    # post-processing
    image = self.movq.decode(latents, force_not_quantize=True)["sample"]

    if output_type in ["np", "pil"]:
        # Rescale with x * 0.5 + 0.5, clamp to [0, 1], and move axis 1 to the end before converting to NumPy.
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
from typing import List, Optional, Union import torch from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import numpy as np >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline >>> from transformers import pipeline >>> from diffusers.utils import load_image >>> def make_hint(image, depth_estimator): ... image = depth_estimator(image)["depth"] ... image = np.array(image) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... detected_map = torch.from_numpy(image).float() / 255.0 ... hint = detected_map.permute(2, 0, 1) ... return hint >>> depth_estimator = pipeline("depth-estimation") >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior = pipe_prior.to("cuda") >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... 
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    """Divide each of `height` and `width` by `scale_factor**2`, rounding up, then multiply by `scale_factor`."""
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


class KandinskyV22ControlnetPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using Kandinsky, conditioned on an additional controlnet `hint`.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # One factor of two per MoVQ block after the first; used to convert pixel sizes to latent sizes.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        """
        Return the initial latents scaled by `scheduler.init_noise_sigma`.

        Fresh noise of the given `shape` is sampled when `latents` is `None`; otherwise the supplied tensor is
        moved to `device` after its shape has been checked.

        Raises:
            ValueError: If `latents` is given and its shape differs from `shape`.
        """
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.

        Raises:
            ImportError: If `accelerate` is not available.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Raises:
            ImportError: If `accelerate` is missing or older than `0.17.0.dev0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Chain the hooks so each model knows about the one that ran before it.
        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        hint: torch.FloatTensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation. A
                list is concatenated along dimension 0.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only consumed when classifier-free guidance is active (`guidance_scale > 1`).
            hint (`torch.FloatTensor`):
                The controlnet condition. A list of tensors is also accepted and concatenated along dimension 0.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`

        Raises:
            ValueError: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`.
        """
        # Reject an unsupported output type before any work is done instead of after the full denoising loop.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
        if isinstance(hint, list):
            hint = torch.cat(hint, dim=0)

        batch_size = image_embeds.shape[0] * num_images_per_prompt

        # The conditioning must match the latent batch (`batch_size`) and the UNet dtype/device whether or not
        # guidance is used; previously this only happened on the guided path.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Negative embeddings first, so that `chunk(2)` below yields (unconditional, conditional).
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            # Both guidance halves see the same hint.
            hint = torch.cat([hint, hint], dim=0)
        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)
        hint = hint.to(dtype=self.unet.dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.movq.config.latent_channels

        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # The conditioning does not change between steps, so build it once.
        added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}

        for t in self.progress_bar(timesteps_tensor):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # Split the UNet output channel-wise into the noise part and the variance part, apply guidance
                # to the noise part only, then re-attach the variance of the conditional half.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Schedulers without a learned variance only receive the noise channels.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type in ["np", "pil"]:
            # Rescale with x * 0.5 + 0.5, clamp to [0, 1], and move axis 1 to the end before converting to NumPy.
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
from typing import List, Optional, Union import numpy as np import PIL import torch from PIL import Image from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import numpy as np >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline >>> from transformers import pipeline >>> from diffusers.utils import load_image >>> def make_hint(image, depth_estimator): ... image = depth_estimator(image)["depth"] ... image = np.array(image) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... detected_map = torch.from_numpy(image).float() / 255.0 ... hint = detected_map.permute(2, 0, 1) ... return hint >>> depth_estimator = pipeline("depth-estimation") >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior = pipe_prior.to("cuda") >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... 
).resize((768, 768)) >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") >>> prompt = "A robot, 4k photo" >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" >>> generator = torch.Generator(device="cuda").manual_seed(43) >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator) >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator) >>> images = pipe( ... image=img, ... strength=0.5, ... image_embeds=img_emb.image_embeds, ... negative_image_embeds=negative_emb.image_embeds, ... hint=hint, ... num_inference_steps=50, ... generator=generator, ... height=768, ... width=768, ... 
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    """Divide each of `height` and `width` by `scale_factor**2`, rounding up, then multiply by `scale_factor`."""
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
    """
    Convert a PIL image into a float32 tensor with a leading batch axis of size 1.

    The image is resized to `(w, h)` with bicubic resampling, converted to RGB, mapped from `[0, 255]` to
    `[-1, 1]` (`x / 127.5 - 1`) and transposed so that the channel axis comes before the two spatial axes.
    """
    pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr).unsqueeze(0)
    return image


class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for image-to-image generation using Kandinsky, conditioned on an additional controlnet `hint`.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # One factor of two per MoVQ block after the first.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        """
        Return the tail of the scheduler's timesteps selected by `strength`, and how many steps that is.

        `int(num_inference_steps * strength)` steps (capped at `num_inference_steps`) are kept from the end of
        `self.scheduler.timesteps`. `device` is accepted but not used here.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """
        Return `image` as latents with scheduler noise added for `timestep`.

        An input whose second dimension is 4 is treated as already-encoded latents; anything else is encoded with
        `self.movq` and multiplied by `self.movq.config.scaling_factor`.

        Raises:
            ValueError: If `image` has an unsupported type, or if `generator` is a list whose length differs from
                `batch_size * num_images_per_prompt` (only checked when `image` has to be encoded).
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # NOTE(review): the check above lets PIL images and lists through, but `.to(...)` below only works for
        # tensors; the one caller in this class always passes a tensor.
        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                # One generator per sample: encode and sample each image separately.
                init_latents = [
                    self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = self.movq.encode(image).latent_dist.sample(generator)

            init_latents = self.movq.config.scaling_factor * init_latents

        init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.

        Raises:
            ImportError: If `accelerate` is not available.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Raises:
            ImportError: If `accelerate` is missing or older than `0.17.0.dev0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Chain the hooks so each model knows about the one that ran before it.
        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        hint: torch.FloatTensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        strength: float = 0.3,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation. A
                list is concatenated along dimension 0.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. PIL images are resized to `(width, height)` and rescaled by `prepare_image`; tensors are
                handed to the MoVQ encoder unchanged, so they are expected to be preprocessed already.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only consumed when classifier-free guidance is active (`guidance_scale > 1`).
            hint (`torch.FloatTensor`):
                The controlnet condition. A list of tensors is also accepted and concatenated along dimension 0.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`

        Raises:
            ValueError: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`, or if `image` contains anything
                other than PIL images and tensors.
        """
        # Reject an unsupported output type before any work is done instead of after the full denoising loop.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
        if isinstance(hint, list):
            hint = torch.cat(hint, dim=0)

        batch_size = image_embeds.shape[0]

        # The conditioning must match the latent batch (`batch_size * num_images_per_prompt`) and the UNet
        # dtype/device whether or not guidance is used; previously this only happened on the guided path.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Negative embeddings first, so that `chunk(2)` below yields (unconditional, conditional).
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            # Both guidance halves see the same hint.
            hint = torch.cat([hint, hint], dim=0)
        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)
        hint = hint.to(dtype=self.unet.dtype, device=device)

        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )

        # `prepare_image` relies on PIL methods, so only PIL inputs go through it. Tensors, which the check
        # above explicitly allows, are used as they are.
        # NOTE(review): tensors are assumed to be batched with the channel axis second and already scaled
        # the way `prepare_image` scales PIL input - confirm against callers.
        image = torch.cat(
            [prepare_image(i, width, height) if isinstance(i, PIL.Image.Image) else i for i in image], dim=0
        )
        image = image.to(dtype=image_embeds.dtype, device=device)

        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        # Every sample starts from the first retained timestep.
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        latents = self.prepare_latents(
            latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
        )

        # The conditioning does not change between steps, so build it once.
        added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # Split the UNet output channel-wise into the noise part and the variance part, apply guidance
                # to the noise part only, then re-attach the variance of the conditional half.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Schedulers without a learned variance only receive the noise channels.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type in ["np", "pil"]:
            # Rescale with x * 0.5 + 0.5, clamp to [0, 1], and move axis 1 to the end before converting to NumPy.
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
) >>> pipe_prior.to("cuda") >>> prompt = "A red cartoon frog, 4k" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyV22Img2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/frog.png" ... ) >>> image = pipe( ... image=init_image, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... strength=0.2, ... ).images >>> image[0].save("red_frog.png") ``` """ # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width def downscale_height_and_width(height, width, scale_factor=8): new_height = height // scale_factor**2 if height % scale_factor**2 != 0: new_height += 1 new_width = width // scale_factor**2 if width % scale_factor**2 != 0: new_width += 1 return new_height * scale_factor, new_width * scale_factor # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image def prepare_image(pil_image, w=512, h=512): pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) arr = np.array(pil_image.convert("RGB")) arr = arr.astype(np.float32) / 127.5 - 1 arr = np.transpose(arr, [2, 0, 1]) image = torch.from_numpy(arr).unsqueeze(0) return image class KandinskyV22Img2ImgPipeline(DiffusionPipeline): """ Pipeline for image-to-image generation using Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. 
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. movq ([`VQModel`]): MoVQ Decoder to generate the image from the latents. """ def __init__( self, unet: UNet2DConditionModel, scheduler: DDPMScheduler, movq: VQModel, ): super().__init__() self.register_modules( unet=unet, scheduler=scheduler, movq=movq, ) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) elif isinstance(generator, list): init_latents = [ self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.movq.encode(image).latent_dist.sample(generator) init_latents = self.movq.config.scaling_factor * init_latents init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. 
Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, strength: float = 0.3, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. 
A value of 1, therefore, essentially ignores `image`. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) batch_size = image_embeds.shape[0] if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) image = image.to(dtype=image_embeds.dtype, device=device) latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) height, width = downscale_height_and_width(height, width, self.movq_scale_factor) latents = self.prepare_latents( latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator ) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents added_cond_kwargs = {"image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=None, 
added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, )[0] # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import List, Optional, Union

import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from ...models import UNet2DConditionModel, VQModel
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import DDPMScheduler
from ...utils import (
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
    replace_example_docstring,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline
        >>> from diffusers.utils import load_image
        >>> import torch
        >>> import numpy as np

        >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
        ... )
        >>> pipe_prior.to("cuda")

        >>> prompt = "a hat"
        >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

        >>> pipe = KandinskyV22InpaintPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> init_image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/cat.png"
        ... )

        >>> mask = np.ones((768, 768), dtype=np.float32)
        >>> mask[:250, 250:-250] = 0

        >>> out = pipe(
        ...     image=init_image,
        ...     mask_image=mask,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=50,
        ... )

        >>> image = out.images[0]
        >>> image.save("cat_with_hat.png")
        ```
"""


# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
    """Round `height` and `width` up to a multiple of `scale_factor**2`, then divide by `scale_factor`.

    Returns:
        `tuple`: `(new_height, new_width)`.
    """
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask
def prepare_mask(masks):
    """Grow the non-1 regions of each mask by one pixel towards six of its eight neighbours.

    For every pixel whose value in channel 0 is not exactly 1, the pixels above, left, above-left, below, right
    and below-right of it are set to 0 in all channels. The pixel itself and the two remaining diagonals are left
    untouched. The input is not modified.

    Args:
        masks (`torch.Tensor`): batch of masks, indexed as `(batch, channel, row, column)`.

    Returns:
        `torch.Tensor`: a new tensor with the same shape as `masks`.
    """
    prepared_masks = []
    for mask in masks:
        # Every decision is taken on the untouched input, so the six shifted writes below are order-independent
        # and replace the former per-pixel Python loop.
        hole = mask[0] != 1
        grown = mask.clone()
        grown[:, :-1, :].masked_fill_(hole[1:, :], 0)  # (i - 1, j)
        grown[:, :, :-1].masked_fill_(hole[:, 1:], 0)  # (i, j - 1)
        grown[:, :-1, :-1].masked_fill_(hole[1:, 1:], 0)  # (i - 1, j - 1)
        grown[:, 1:, :].masked_fill_(hole[:-1, :], 0)  # (i + 1, j)
        grown[:, :, 1:].masked_fill_(hole[:, :-1], 0)  # (i, j + 1)
        grown[:, 1:, 1:].masked_fill_(hole[:-1, :-1], 0)  # (i + 1, j + 1)
        prepared_masks.append(grown)
    return torch.stack(prepared_masks, dim=0)


# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask_and_masked_image
def prepare_mask_and_masked_image(image, mask, height, width):
    r"""
    Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (``mask >= 0.5``). PIL masks are converted to ``float32``; tensor and array masks keep their dtype.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
        height (`int`): The height in pixels PIL inputs are resized to.
        width (`int`): The width in pixels PIL inputs are resized to.


    Raises:
        ValueError: ``image`` or ``mask`` is ``None``.
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        AssertionError: ``mask`` and ``image`` tensors should have the same batch size and spatial dimensions.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
            dimensions: ``batch x channels x height x width``.
    """

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask into a new tensor so the caller's tensor is not modified in place
        mask = (mask >= 0.5).to(dtype=mask.dtype)

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t passed height and width
            image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `mask` is a freshly concatenated array here, so binarizing in place does not touch the caller's data.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    return mask, image


class KandinskyV22InpaintPipeline(DiffusionPipeline):
    """
    Pipeline for text-guided image inpainting using Kandinsky2.1

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # One factor of two per MoVQ resolution level after the first.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        """Return starting latents of `shape`, scaled by `scheduler.init_noise_sigma`.

        Fresh noise is sampled when `latents` is `None`; otherwise the given tensor is moved to `device`.

        Raises:
            ValueError: if `latents` is given with a shape other than `shape`.
        """
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            mask_image (`np.array`):
                Tensor representing an image batch, to mask `image`. Black pixels in the mask will be repainted, while
                white pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single
                channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of
                3, so the expected shape would be `(B, 1, H, W)`.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only used when guidance is enabled (`guidance_scale > 1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        # Reject an unsupported output type before doing any work rather than after the whole denoising loop.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        batch_size = image_embeds.shape[0] * num_images_per_prompt
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        # The conditioning must match the latent batch and the UNet's dtype/device whether or not guidance is
        # enabled.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        # preprocess image and mask
        mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)

        image = image.to(dtype=image_embeds.dtype, device=device)
        image = self.movq.encode(image)["latents"]

        mask_image = mask_image.to(dtype=image_embeds.dtype, device=device)

        # Bring the mask down to the spatial size of the encoded image.
        image_shape = tuple(image.shape[-2:])
        mask_image = F.interpolate(
            mask_image,
            image_shape,
            mode="nearest",
        )
        mask_image = prepare_mask(mask_image)
        masked_image = image * mask_image

        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
        masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            mask_image = mask_image.repeat(2, 1, 1, 1)
            masked_image = masked_image.repeat(2, 1, 1, 1)

        num_channels_latents = self.movq.config.latent_channels

        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )
        # Kept so the preserved region can be re-noised with the same noise at every step.
        noise = torch.clone(latents)
        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)

            added_cond_kwargs = {"image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The UNet output stacks noise and variance along the channel axis; guidance is applied to the
                # noise half only and the conditional variance is re-attached afterwards.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                # Schedulers without a learned variance only take the noise channels.
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

            # NOTE(review): only the first image's latents and mask are used and broadcast over the batch;
            # confirm this is intended when several different images are passed.
            init_latents_proper = image[:1]
            init_mask = mask_image[:1]

            if i < len(timesteps_tensor) - 1:
                # Re-noise the original latents to the next timestep so they match the current sample.
                noise_timestep = timesteps_tensor[i + 1]
                init_latents_proper = self.scheduler.add_noise(
                    init_latents_proper, noise, torch.tensor([noise_timestep])
                )

            # Keep the original content where the mask is 1, the denoised sample elsewhere.
            latents = init_mask * init_latents_proper + (1 - init_mask) * latents

        # post-processing
        latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type in ["np", "pil"]:
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
================================================ FILE: 4DoF/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py ================================================ from typing import List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) from ..kandinsky import KandinskyPriorPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple() >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... ).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... 
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> out = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyV22Pipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=out.image_embeds, ... negative_image_embeds=out.negative_image_embeds, ... height=768, ... width=768, ... num_inference_steps=50, ... ).images[0] >>> image.save("starry_cat.png") ``` """ class KandinskyV22PriorPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. image_processor ([`CLIPImageProcessor`]): A image_processor to be used to preprocess image from clip. 
""" def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( prior=prior, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) @torch.no_grad() @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) def interpolate( self, images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], weights: List[float], num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, negative_prior_prompt: Optional[str] = None, negative_prompt: Union[str] = "", guidance_scale: float = 4.0, device=None, ): """ Function invoked when using the prior pipeline for interpolation. Args: images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): list of prompts and images to guide the image generation. weights: (`List[float]`): list of weights for each condition in `images_and_prompts` num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. negative_prior_prompt (`str`, *optional*): The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ device = device or self.device if len(images_and_prompts) != len(weights): raise ValueError( f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" ) image_embeddings = [] for cond, weight in zip(images_and_prompts, weights): if isinstance(cond, str): image_emb = self( cond, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ).image_embeds.unsqueeze(0) elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): if isinstance(cond, PIL.Image.Image): cond = ( self.image_processor(cond, return_tensors="pt") .pixel_values[0] .unsqueeze(0) .to(dtype=self.image_encoder.dtype, device=device) ) image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0) else: raise ValueError( f"`images_and_prompts` can only contains 
elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" ) image_embeddings.append(image_emb * weight) image_emb = torch.cat(image_embeddings).sum(dim=0) out_zero = self( negative_prompt, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ) zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed def get_zero_embed(self, batch_size=1, device=None): device = device or self.device zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( device=device, dtype=self.image_encoder.dtype ) zero_image_emb = self.image_encoder(zero_img)["image_embeds"] zero_image_emb = zero_image_emb.repeat(batch_size, 1) return zero_image_emb # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.image_encoder, self.text_encoder, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): return self.device for module in self.text_encoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not 
torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, guidance_scale: float = 4.0, output_type: Optional[str] = "pt", # pt only return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. output_type (`str`, *optional*, defaults to `"pt"`): The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] elif not isinstance(negative_prompt, list) and negative_prompt is not None: raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") # if the negative prompt is defined we double the batch size to # directly retrieve the negative prompt embedding if negative_prompt is not None: prompt = prompt + negative_prompt negative_prompt = 2 * negative_prompt device = self._execution_device batch_size = len(prompt) batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # prior self.scheduler.set_timesteps(num_inference_steps, device=device) prior_timesteps_tensor = 
self.scheduler.timesteps embedding_dim = self.prior.config.embedding_dim latents = self.prepare_latents( (batch_size, embedding_dim), prompt_embeds.dtype, device, generator, latents, self.scheduler, ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = prior_timesteps_tensor[i + 1] latents = self.scheduler.step( predicted_image_embedding, timestep=t, sample=latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample latents = self.prior.post_process_latents(latents) image_embeddings = latents # if negative prompt has been defined, we retrieve split the image embedding into two if negative_prompt is None: zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) else: image_embeddings, zero_embeds = image_embeddings.chunk(2) if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") if output_type == "np": image_embeddings = image_embeddings.cpu().numpy() zero_embeds = zero_embeds.cpu().numpy() if not return_dict: return (image_embeddings, zero_embeds) return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) ================================================ FILE: 
4DoF/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py ================================================ from typing import List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) from ..kandinsky import KandinskyPriorPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() >>> pipe = KandinskyPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16" ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... ).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... 
) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyV22Pipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=150, ... ).images[0] >>> image.save("starry_cat.png") ``` """ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. 
""" def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( prior=prior, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start @torch.no_grad() @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) def interpolate( self, images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], weights: List[float], num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, negative_prior_prompt: Optional[str] = None, negative_prompt: Union[str] = "", guidance_scale: float = 4.0, device=None, ): """ Function invoked when using the prior pipeline for interpolation. Args: images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): list of prompts and images to guide the image generation. weights: (`List[float]`): list of weights for each condition in `images_and_prompts` num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. negative_prior_prompt (`str`, *optional*): The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ device = device or self.device if len(images_and_prompts) != len(weights): raise ValueError( f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" ) image_embeddings = [] for cond, weight in zip(images_and_prompts, weights): if isinstance(cond, str): image_emb = self( cond, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ).image_embeds.unsqueeze(0) elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): image_emb = self._encode_image( cond, device=device, num_images_per_prompt=num_images_per_prompt ).unsqueeze(0) else: raise ValueError( f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" ) image_embeddings.append(image_emb * weight) image_emb = torch.cat(image_embeddings).sum(dim=0) return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb)) def _encode_image( self, image: Union[torch.Tensor, List[PIL.Image.Image]], device, num_images_per_prompt, ): if not isinstance(image, torch.Tensor): image = self.image_processor(image, return_tensors="pt").pixel_values.to( dtype=self.image_encoder.dtype, device=device ) image_emb = self.image_encoder(image)["image_embeds"] # B, D image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0) image_emb.to(device=device) return image_emb def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): emb = emb.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt init_latents = emb if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: additional_image_per_prompt = batch_size // init_latents.shape[0] 
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed def get_zero_embed(self, batch_size=1, device=None): device = device or self.device zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( device=device, dtype=self.image_encoder.dtype ) zero_image_emb = self.image_encoder(zero_img)["image_embeds"] zero_image_emb = zero_image_emb.repeat(batch_size, 1) return zero_image_emb # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.image_encoder, self.text_encoder, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): return self.device for module in self.text_encoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated 
because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]],
    strength: float = 0.3,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    guidance_scale: float = 4.0,
    output_type: Optional[str] = "pt",  # pt only
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image (`torch.Tensor`, `List[torch.Tensor]`, `PIL.Image.Image` or `List[PIL.Image.Image]`):
            The reference image(s) whose embedding is used as the starting point. Tensors in a list are
            concatenated along dim 0. A 2-D tensor is treated as ready-made image embeddings and is not
            encoded again; any other tensor must be 4-D (`[batch_size, channels, height, width]`). Everything
            else is encoded with `_encode_image`.
        strength (`float`, *optional*, defaults to 0.3):
            Conceptually, indicates how much to transform the reference `image` embedding. Must be between 0
            and 1. The embedding of `image` will be used as a starting point, adding more noise to it the
            larger the `strength`. The number of denoising steps depends on the amount of noise initially
            added.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. When given, it is appended to `prompt`
            so that the negative image embedding is produced in the same pass.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Accepted for signature compatibility. NOTE: the value passed here is currently not used; it is
            overwritten by the embedding derived from `image` before noise is added.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        output_type (`str`, *optional*, defaults to `"pt"`):
            The output format of the generated embeddings. Choose between: `"np"` (`np.array`) or `"pt"`
            (`torch.Tensor`). Any other value raises a `ValueError` after sampling has finished.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`KandinskyPriorPipelineOutput`] or `tuple`: with `return_dict=False`, the tuple
        `(image_embeddings, zero_embeds)`.

    Raises:
        ValueError: if `prompt` or `negative_prompt` is neither `str` nor `list`, if a tensor `image` is
            neither 2-D nor 4-D, or if `output_type` is not `"pt"` or `"np"`.
    """

    # Normalize `prompt` / `negative_prompt` to lists so they can be concatenated below.
    if isinstance(prompt, str):
        prompt = [prompt]
    elif not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    elif not isinstance(negative_prompt, list) and negative_prompt is not None:
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # if the negative prompt is defined we double the batch size to
    # directly retrieve the negative prompt embedding
    if negative_prompt is not None:
        prompt = prompt + negative_prompt
        negative_prompt = 2 * negative_prompt

    device = self._execution_device

    batch_size = len(prompt)
    batch_size = batch_size * num_images_per_prompt

    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Wrap a single image so the list handling below covers every input form.
    if not isinstance(image, List):
        image = [image]

    if isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    if isinstance(image, torch.Tensor) and image.ndim == 2:
        # allow user to pass image_embeds directly
        image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0)
    elif isinstance(image, torch.Tensor) and image.ndim != 4:
        raise ValueError(
            f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}"
        )
    else:
        image_embeds = self._encode_image(image, device, num_images_per_prompt)

    # prior
    self.scheduler.set_timesteps(num_inference_steps, device=device)

    # The caller-supplied `latents` is discarded here: sampling always starts from the image embedding.
    latents = image_embeds
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    # The first returned timestep, repeated per batch entry, is the level at which noise is added.
    latent_timestep = timesteps[:1].repeat(batch_size)
    latents = self.prepare_latents(
        latents,
        latent_timestep,
        batch_size // num_images_per_prompt,
        num_images_per_prompt,
        prompt_embeds.dtype,
        device,
        generator,
    )

    for i, t in enumerate(self.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        predicted_image_embedding = self.prior(
            latent_model_input,
            timestep=t,
            proj_embedding=prompt_embeds,
            encoder_hidden_states=text_encoder_hidden_states,
            attention_mask=text_mask,
        ).predicted_image_embedding

        if do_classifier_free_guidance:
            # First half of the batch is the unconditional prediction, second half the text-conditioned one.
            predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
            predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                predicted_image_embedding_text - predicted_image_embedding_uncond
            )

        # The scheduler is told the next timestep explicitly; `None` marks the final step.
        if i + 1 == timesteps.shape[0]:
            prev_timestep = None
        else:
            prev_timestep = timesteps[i + 1]

        latents = self.scheduler.step(
            predicted_image_embedding,
            timestep=t,
            sample=latents,
            generator=generator,
            prev_timestep=prev_timestep,
        ).prev_sample

    latents = self.prior.post_process_latents(latents)

    image_embeddings = latents

    # if negative prompt has been defined, we split the image embedding into two
    if negative_prompt is None:
        zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
    else:
        image_embeddings, zero_embeds = image_embeddings.chunk(2)

    # NOTE: `output_type` is only validated here, after the sampling loop has already run.
    if output_type not in ["pt", "np"]:
        raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

    if output_type == "np":
        image_embeddings = image_embeddings.cpu().numpy()
        zero_embeds = zero_embeds.cpu().numpy()

    if not return_dict:
        return (image_embeddings, zero_embeds)

    return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
class LDMTextToImagePipeline(DiffusionPipeline):
    r"""
    Text-to-image latent diffusion pipeline.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
        bert ([`LDMBertModel`]):
            Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
        tokenizer (`transformers.BertTokenizer`):
            Tokenizer of class
            [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: Union[VQModel, AutoencoderKL],
        bert: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        unet: Union[UNet2DModel, UNet2DConditionModel],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
        # Each resolution level after the first halves the spatial size of the latents.
        num_levels = len(self.vqvae.config.block_out_channels)
        self.vae_scale_factor = 2 ** (num_levels - 1)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Generate images from text prompts.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. Must be divisible by 8.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. Must be divisible by 8.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is applied whenever `guidance_scale` differs
                from `1.0`; a value of exactly `1.0` skips the unconditional pass.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` only when that method's signature has an `eta` parameter.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic. A list must have one entry per prompt.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents to start from. Must have shape
                `(batch_size, unet.config.in_channels, height // 8, width // 8)`. If not provided, latents are
                sampled using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` converts the result with `numpy_to_pil`; any other value returns the `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            **kwargs:
                Accepted and ignored.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is the generated images.
        """
        # 0. Default height and width to unet
        default_side = self.unet.config.sample_size * self.vae_scale_factor
        height = height or default_side
        width = width or default_side

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not (height % 8 == 0 and width % 8 == 0):
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        use_guidance = guidance_scale != 1.0
        tokenize_kwargs = {"padding": "max_length", "max_length": 77, "truncation": True, "return_tensors": "pt"}

        # get unconditional embeddings for classifier free guidance
        if use_guidance:
            empty_tokens = self.tokenizer([""] * batch_size, **tokenize_kwargs)
            negative_prompt_embeds = self.bert(empty_tokens.input_ids.to(self.device))[0]

        # get prompt text embeddings
        prompt_tokens = self.tokenizer(prompt, **tokenize_kwargs)
        prompt_embeds = self.bert(prompt_tokens.input_ids.to(self.device))[0]

        # get the initial random noise unless the user supplied it
        expected_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                expected_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype
            )
        elif latents.shape != expected_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {expected_shape}")
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        step_parameters = inspect.signature(self.scheduler.step).parameters
        step_kwargs = {"eta": eta} if "eta" in step_parameters else {}

        for t in self.progress_bar(self.scheduler.timesteps):
            if use_guidance:
                # One batched forward pass covers both the unconditional and the text-conditioned prediction.
                model_input = torch.cat([latents] * 2)
                context = torch.cat([negative_prompt_embeds, prompt_embeds])
            else:
                # guidance_scale of 1 means no guidance
                model_input = latents
                context = prompt_embeds

            # predict the noise residual
            noise_pred = self.unet(model_input, t, encoder_hidden_states=context).sample

            # perform guidance
            if use_guidance:
                uncond_pred, text_pred = noise_pred.chunk(2)
                noise_pred = uncond_pred + guidance_scale * (text_pred - uncond_pred)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / self.vqvae.config.scaling_factor * latents
        decoded = self.vqvae.decode(latents).sample

        # Map from [-1, 1] to [0, 1], then to channels-last numpy.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        image = decoded.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if return_dict:
            return ImagePipelineOutput(images=image)
        return (image,)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Positions where `mask` is 1 become `0.0` in the result; positions where it is 0 become the most negative
    finite value of `dtype`, so the result can be added to attention scores. `tgt_len` defaults to the source
    length.
    """
    batch, source_len = mask.size()
    if tgt_len is None:
        tgt_len = source_len

    # Broadcast the per-token mask over a singleton head axis and every target position.
    keep = mask[:, None, None, :].expand(batch, 1, tgt_len, source_len).to(dtype)
    blocked = 1.0 - keep

    most_negative = torch.finfo(dtype).min
    return blocked.masked_fill(blocked.to(torch.bool), most_negative)
cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. 
Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. 
class LDMBertEncoderLayer(nn.Module):
    """
    One encoder layer: layer-normed self-attention followed by a layer-normed two-layer feed-forward network,
    each wrapped in a residual connection.
    """

    def __init__(self, config: LDMBertConfig):
        super().__init__()
        width = config.d_model
        self.embed_dim = width
        self.self_attn = LDMBertAttention(
            embed_dim=width,
            num_heads=config.encoder_attention_heads,
            head_dim=config.head_dim,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(width, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer; the attention module unpacks it as
                `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        dropout = nn.functional.dropout

        # Self-attention sub-block: normalize first, add the residual afterwards.
        attn_out, attn_weights, _ = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + dropout(attn_out, p=self.dropout, training=self.training)

        # Feed-forward sub-block, same normalize-then-residual layout.
        ffn = self.activation_fn(self.fc1(self.final_layer_norm(hidden_states)))
        ffn = dropout(ffn, p=self.activation_dropout, training=self.training)
        ffn = dropout(self.fc2(ffn), p=self.dropout, training=self.training)
        hidden_states = hidden_states + ffn

        # In half precision, clamp to just inside the finite range if inf/nan values appeared.
        if hidden_states.dtype == torch.float16:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                limit = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
[r"encoder\.version", r"decoder\.version"] def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (LDMBertEncoder,)): module.gradient_checkpointing = value @property def dummy_inputs(self): pad_token = self.config.pad_token_id input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) dummy_inputs = { "attention_mask": input_ids.ne(pad_token), "input_ids": input_ids, } return dummy_inputs class LDMBertEncoder(LDMBertPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`LDMBertEncoderLayer`]. Args: config: LDMBertConfig embed_tokens (nn.Embedding): output embedding """ def __init__(self, config: LDMBertConfig): super().__init__(config) self.dropout = config.dropout embed_dim = config.d_model self.padding_idx = config.pad_token_id self.max_source_positions = config.max_position_embeddings self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: 
Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) seq_len = input_shape[1] if position_ids is None: position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) embed_pos = self.embed_positions(position_ids) hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: if head_mask.size()[0] != (len(self.layers)): raise ValueError( f"The head_mask should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." 
class LDMBertModel(LDMBertPreTrainedModel):
    """
    Wrapper that owns an [`LDMBertEncoder`] (as `self.model`) and forwards every call to it.

    A `to_logits` linear layer from `config.hidden_size` to `config.vocab_size` is also created, but `forward` does
    not apply it: the encoder outputs are returned as they are.
    """

    _no_split_modules = []

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)
        self.model = LDMBertEncoder(config)
        self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the wrapped encoder with the given arguments and return its output unchanged."""
        # `input_ids` stays positional; everything else is passed by keyword, as the encoder expects.
        encoder_kwargs = {
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.model(input_ids, **encoder_kwargs)
def preprocess(image):
    """
    Turn a PIL image into a batched float32 tensor with values scaled to `[-1, 1]`.

    Width and height are first rounded down to a multiple of 32 and the image is resized with Lanczos resampling.
    The pixel array is divided by 255, given a leading batch axis, and its last axis is moved to position 1
    (assumes the image converts to a `height x width x channels` array -- TODO confirm for non-RGB modes).
    """
    width, height = image.size
    # resize to integer multiple of 32
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])

    pixels = np.array(resized).astype(np.float32) / 255.0
    batched = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    return 2.0 * batched - 1.0


class LDMSuperResolutionPipeline(DiffusionPipeline):
    r"""
    A pipeline for image super-resolution using latent diffusion.

    This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]):
            U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: VQModel,
        unet: UNet2DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        num_inference_steps: Optional[int] = 100,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            image (`torch.Tensor` or `PIL.Image.Image`):
                Low-resolution input. A PIL image is converted with `preprocess`; a tensor is used as given. It is
                concatenated with the latents along the channel dimension at every denoising step.
            batch_size (`int`, *optional*, defaults to 1):
                Overwritten inside the call: 1 for a PIL image, `image.shape[0]` for a tensor.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only passed to
                the scheduler when its `step` method accepts an `eta` argument.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` converts the result with `numpy_to_pil`; any other value returns the NumPy array.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is the generated images.

        Raises:
            ValueError: If `image` is neither a `PIL.Image.Image` nor a `torch.Tensor`.
        """
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
            image = preprocess(image)
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")

        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        noise_shape = (batch_size, self.unet.config.in_channels // 2, height, width)
        unet_dtype = next(self.unet.parameters()).dtype

        latents = randn_tensor(noise_shape, generator=generator, device=self.device, dtype=unet_dtype)
        image = image.to(device=self.device, dtype=unet_dtype)

        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # Not every scheduler's `step` takes `eta`, so it is forwarded only when the signature has it.
        step_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            step_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps):
            # concat latents and low resolution image in the channel dimension.
            unet_input = torch.cat([latents, image], dim=1)
            unet_input = self.scheduler.scale_model_input(unet_input, t)
            # predict the noise residual
            noise_pred = self.unet(unet_input, t).sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        # decode the image latents with the VQVAE
        decoded = self.vqvae.decode(latents).sample
        decoded = torch.clamp(decoded, -1.0, 1.0)
        decoded = decoded / 2 + 0.5
        decoded = decoded.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            decoded = self.numpy_to_pil(decoded)

        if not return_dict:
            return (decoded,)

        return ImagePipelineOutput(images=decoded)
class LDMPipeline(DiffusionPipeline):
    r"""
    Unconditional latent-diffusion image generation pipeline.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
    """

    def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` only when that method accepts an `eta` argument.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` converts the result with `numpy_to_pil`; any other value returns the NumPy array.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            kwargs:
                Accepted but not used by this method.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is the generated images.
        """
        unet_config = self.unet.config
        noise_shape = (batch_size, unet_config.in_channels, unet_config.sample_size, unet_config.sample_size)

        # The noise is drawn without a target device and moved to the pipeline device afterwards.
        latents = randn_tensor(noise_shape, generator=generator).to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        step_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            step_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            model_input = self.scheduler.scale_model_input(latents, t)
            # predict the noise residual
            noise_pred = self.unet(model_input, t).sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        # decode the image latents with the VAE
        decoded = self.vqvae.decode(latents).sample
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = decoded.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            decoded = self.numpy_to_pil(decoded)

        if not return_dict:
            return (decoded,)

        return ImagePipelineOutput(images=decoded)
class OnnxRuntimeModel:
    """
    Wrapper around an ONNX Runtime inference session.

    Besides the session (`self.model`) it remembers the directory the model file was loaded from
    (`model_save_dir`) and that file's name (`latest_model_name`), so that `save_pretrained` can copy it.
    """

    def __init__(self, model=None, **kwargs):
        logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
        self.model = model
        self.model_save_dir = kwargs.get("model_save_dir", None)
        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

    def __call__(self, **kwargs):
        """Convert every keyword argument to a NumPy array and run the session, returning all of its outputs."""
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Path handed to `ort.InferenceSession` (the callers in this class pass the model file itself).
            provider(`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
            sess_options (*optional*):
                Forwarded unchanged to `ort.InferenceSession`.
        """
        if provider is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)

    def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
        """
        Copy the loaded model file (and its external-weights file, if one exists) into `save_directory`.

        The source is always `model_save_dir / latest_model_name`.

        Arguments:
            save_directory (`str` or `Path`):
                Directory where to save the model file.
            file_name(`str`, *optional*):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
                model with a different name.
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME

        src_path = self.model_save_dir.joinpath(self.latest_model_name)
        dst_path = Path(save_directory).joinpath(model_file_name)
        try:
            shutil.copyfile(src_path, dst_path)
        except shutil.SameFileError:
            # Saving into the directory the model came from: nothing to copy.
            pass

        # copy external weights (for models >2GB)
        src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
        if src_path.exists():
            dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
            try:
                shutil.copyfile(src_path, dst_path)
            except shutil.SameFileError:
                pass

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
        method.

        If `save_directory` is an existing file, an error is logged and nothing is saved.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory, **kwargs)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["ort.SessionOptions"] = None,
        **kwargs,
    ):
        """
        Load a model from a directory or the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Directory from which to load
            use_auth_token (`str` or `bool`):
                Is needed to load models from a private or gated repository
            revision (`str`):
                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
            cache_dir (`Union[str, Path]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            file_name(`str`):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
                different model files from the same repository or directory.
            provider(`str`):
                The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
            kwargs (`Dict`, *optional*):
                kwargs will be passed to the model during initialization
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(model_id):
            model = OnnxRuntimeModel.load_model(
                os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
            )
            kwargs["model_save_dir"] = Path(model_id)
            # Remember which file was actually loaded, so a later save copies that file and not the default
            # "model.onnx". An explicitly passed `latest_model_name` still takes precedence.
            kwargs.setdefault("latest_model_name", model_file_name)
        # load model from hub
        else:
            # download model
            model_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=model_file_name,
                use_auth_token=use_auth_token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
            )
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
        return cls(model=model, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        force_download: bool = True,
        use_auth_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **model_kwargs,
    ):
        """
        Load a model, accepting an optional `"<model_id>@<revision>"` suffix on `model_id`.

        When `model_id` contains exactly one `"@"`, the part after it is used as the revision and the part before it
        (as a `str`) as the model id. Otherwise `model_id` is forwarded unchanged with `revision=None`. Everything
        else is passed through to `_from_pretrained`. Note that `force_download` defaults to `True` here.
        """
        revision = None
        # Split the string form: `model_id` may be a `Path`, which has no `split` method.
        id_parts = str(model_id).split("@")
        if len(id_parts) == 2:
            model_id, revision = id_parts

        return cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            use_auth_token=use_auth_token,
            **model_kwargs,
        )
class PaintByExampleImageEncoder(CLIPPreTrainedModel):
    """
    Encodes an example image into the conditioning states used by the Paint by Example pipeline.

    The pooled output of a `CLIPVisionModel` goes through a `PaintByExampleMapper`, a layer norm and a linear
    projection to `proj_size`. A learnable `uncond_vector` of shape `(1, 1, proj_size)` is kept alongside and can be
    returned together with the encoded states.
    """

    def __init__(self, config, proj_size=768):
        super().__init__(config)
        self.proj_size = proj_size

        hidden_size = config.hidden_size
        self.model = CLIPVisionModel(config)
        self.mapper = PaintByExampleMapper(config)
        self.final_layer_norm = nn.LayerNorm(hidden_size)
        self.proj_out = nn.Linear(hidden_size, self.proj_size)

        # uncondition for scaling
        self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))

    def forward(self, pixel_values, return_uncond_vector=False):
        """
        Encode `pixel_values`.

        Returns the projected states, or the tuple `(states, uncond_vector)` when `return_uncond_vector` is true.
        """
        pooled = self.model(pixel_values=pixel_values).pooler_output
        # Insert a length-1 axis at position 1 before the mapper blocks.
        states = self.mapper(pooled[:, None])
        states = self.proj_out(self.final_layer_norm(states))

        if not return_uncond_vector:
            return states
        return states, self.uncond_vector
def prepare_mask_and_masked_image(image, mask):
    """
    Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` is inverted
    (``1 - mask``, because Paint by Example uses the opposite convention) and then binarized: inverted values below
    ``0.5`` become ``0``, all others ``1``.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image`` or a ``height x width x 3`` ``np.array`` (or a list of either), or a
            ``channels x height x width`` ``torch.Tensor`` or a ``batch x channels x height x width``
            ``torch.Tensor``. PIL and NumPy pixels are mapped with ``x / 127.5 - 1``, i.e. they are taken to be in
            ``[0, 255]``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image`` or a ``height x width`` ``np.array`` (or a list of either), or a
            ``height x width``, ``1 x height x width`` or ``batch x 1 x height x width`` ``torch.Tensor``. PIL masks
            are divided by 255; NumPy masks are only cast to ``float32`` and are therefore expected to be in
            ``[0, 1]`` already.

    Raises:
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        AssertionError: ``torch.Tensor`` inputs whose rank, spatial dimensions, batch size or mask channel count do
            not match.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels
        x height x width``.
    """
    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Batched mask
            if mask.shape[0] == image.shape[0]:
                mask = mask.unsqueeze(1)
            else:
                mask = mask.unsqueeze(0)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
        assert mask.shape[1] == 1, "Mask image must have a single channel"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # paint-by-example inverses the mask (this also yields a new tensor, so the caller's mask is not modified)
        mask = 1 - mask

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
    else:
        # preprocess image
        if isinstance(image, np.ndarray):
            image = [image]
        if isinstance(image, (list, tuple)) and len(image) > 0 and isinstance(image[0], np.ndarray):
            # NumPy images are stacked as they are (height x width x 3 each).
            image = np.concatenate([i[None, :] for i in image], axis=0)
        else:
            if isinstance(image, PIL.Image.Image):
                image = [image]
            image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, np.ndarray):
            mask = [mask]
        if isinstance(mask, (list, tuple)) and len(mask) > 0 and isinstance(mask[0], np.ndarray):
            # Cast to float so that the inversion below cannot wrap around on unsigned integer masks.
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0).astype(np.float32)
        else:
            if isinstance(mask, PIL.Image.Image):
                mask = [mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0

        # paint-by-example inverses the mask
        mask = 1 - mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    masked_image = image * mask

    return mask, masked_image
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's sub-models with accelerate's `cpu_offload`.

    `unet`, `vae` and `image_encoder` are each registered for offloading with `cuda:{gpu_id}` as their execution
    device. The safety checker, when one is set, is offloaded as well, additionally passing `offload_buffers=True`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index of the CUDA device used as the execution device.

    Raises:
        ImportError: If `is_accelerate_available()` reports that accelerate is not installed.
    """
    # Guard first so the import below is only attempted when accelerate is present.
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")

    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    always_offloaded = [self.unet, self.vae, self.image_encoder]
    for sub_model in always_offloaded:
        cpu_offload(sub_model, execution_device=execution_device)

    # The safety checker is an optional component and may have been disabled.
    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=execution_device, offload_buffers=True)
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Decode `latents` with the VAE into a NumPy image batch (deprecated).

    Emits a `FutureWarning` on every call. The latents are divided by `vae.config.scaling_factor`, decoded,
    mapped with `x / 2 + 0.5`, clamped to `[0, 1]`, moved to CPU, permuted with `(0, 2, 3, 1)` and returned as a
    float32 NumPy array.
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    inverse_scale = 1 / self.vae.config.scaling_factor
    unscaled = inverse_scale * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    channels_last = normalized.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the initial latents for the denoising loop, scaled by `scheduler.init_noise_sigma`.

    When `latents` is `None`, noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn with
    `randn_tensor`; otherwise the supplied `latents` are moved to `device` and used instead.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """
    Encode `image` with the VAE and return sampled latents multiplied by `vae.config.scaling_factor`.

    When `generator` is a list, each batch element is encoded on its own (as a slice of length one) and sampled
    with the generator at the same index, and the results are concatenated along dim 0. Otherwise the whole batch
    is encoded and sampled in one call with the single `generator`.
    """
    if not isinstance(generator, list):
        sampled = self.vae.encode(image).latent_dist.sample(generator=generator)
    else:
        per_item = []
        for index in range(image.shape[0]):
            # Slice rather than index so the batch dimension is kept for the encoder.
            single = image[index : index + 1]
            per_item.append(self.vae.encode(single).latent_dist.sample(generator=generator[index]))
        sampled = torch.cat(per_item, dim=0)

    return self.vae.config.scaling_factor * sampled
@torch.no_grad()
def __call__(
    self,
    example_image: Union[torch.FloatTensor, PIL.Image.Image],
    image: Union[torch.FloatTensor, PIL.Image.Image],
    mask_image: Union[torch.FloatTensor, PIL.Image.Image],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    The masked region of `image` is repainted under guidance of `example_image`: the example image is encoded
    with `image_encoder` and the resulting embeddings are passed to the unet as `encoder_hidden_states`.

    Args:
        example_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
            The exemplar image to guide the image generation.
        image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
            be masked out with `mask_image` and repainted according to `example_image`. The batch size is taken
            from this argument (1 for a single PIL image, `len(image)` for a list, `image.shape[0]` otherwise).
        mask_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3; after batching, `prepare_mask_and_masked_image` requires the shape `(B, 1, H, W)`.
        height (`int`, *optional*):
            Accepted for API compatibility. The value passed in is not used: it is overwritten with the height of
            the preprocessed `image` before anything reads it.
        width (`int`, *optional*):
            Accepted for API compatibility. The value passed in is not used: it is overwritten with the width of
            the preprocessed `image` before anything reads it.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`, in
            which case the unconditional and example-image-conditioned predictions are combined.
        negative_prompt (`str` or `List[str]`, *optional*):
            Accepted for API compatibility. This argument is not read anywhere in this method.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per input image.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. If not provided, a latents tensor will be generated by sampling using the supplied random
            `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            VAE decode and the safety checker are skipped and the final latents are post-processed instead.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the post-processed images, and the second element is
        the `has_nsfw_concept` value produced by `run_safety_checker` (`None` when the safety checker is disabled
        or `output_type` is `"latent"`).
    """
    # 1. Define call parameters
    if isinstance(image, PIL.Image.Image):
        batch_size = 1
    elif isinstance(image, list):
        batch_size = len(image)
    else:
        batch_size = image.shape[0]
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Preprocess mask and image
    mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
    # NOTE: the `height`/`width` arguments are replaced here by the spatial size of the preprocessed image.
    height, width = masked_image.shape[-2:]

    # 3. Check inputs
    self.check_inputs(example_image, height, width, callback_steps)

    # 4. Encode input image
    image_embeddings = self._encode_image(
        example_image, device, num_images_per_prompt, do_classifier_free_guidance
    )

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare mask latent variables
    mask, masked_image_latents = self.prepare_mask_latents(
        mask,
        masked_image,
        batch_size * num_images_per_prompt,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        do_classifier_free_guidance,
    )

    # 8. Check that sizes of mask, masked image and latents match
    # The unet input is the channel-wise concatenation of all three, so their channel counts must add up.
    num_channels_mask = mask.shape[1]
    num_channels_masked_image = masked_image_latents.shape[1]
    if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
        raise ValueError(
            f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
            f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
            f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
            " `pipeline.unet` or your `mask_image` or `image` input."
        )

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # concat latents, mask, masked_image_latents in the channel dimension
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                # first half of the batch is the unconditional pass, second half the conditioned one
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are excluded from denormalization.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def import_flax_or_no_model(module, class_name):
    """
    Look up a class on `module`, preferring its Flax-prefixed variant.

    Args:
        module: Object (typically an imported module) to read the class from.
        class_name (`str`): Class name without the `"Flax"` prefix.

    Returns:
        `getattr(module, "Flax" + class_name)` when that attribute exists, otherwise
        `getattr(module, class_name)`.

    Raises:
        ValueError: If neither attribute exists on `module`.
    """
    try:
        # 1. First make sure that if a Flax object is present, import this one
        return getattr(module, "Flax" + class_name)
    except AttributeError:
        pass

    # 2. If this doesn't work, it's not a model and we don't append "Flax".
    # This lookup needs its own `try`: an exception raised inside an `except` handler is not
    # caught by a sibling `except` clause of the same `try` statement.
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") from err
def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
    # TODO: handle inference_state
    """
    Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
    a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
    method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class
    method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Each component is written to the sub-directory
            `save_directory/<component name>`.
        params (`Dict` or `FrozenDict`):
            Mapping indexed by component name. `params[<component name>]` is forwarded as the `params` keyword
            to every component save method whose signature accepts one.

    Raises:
        TypeError: If a non-`None` component does not derive from any base class listed in `LOADABLE_CLASSES`,
            so no save method can be determined for it.
    """
    self.save_config(save_directory)

    # Work on a copy of the config and drop the bookkeeping entries, leaving one entry per component.
    model_index_dict = dict(self.config)
    model_index_dict.pop("_class_name")
    model_index_dict.pop("_diffusers_version")
    model_index_dict.pop("_module", None)

    for pipeline_component_name in model_index_dict:
        sub_model = getattr(self, pipeline_component_name)
        if sub_model is None:
            # edge case for saving a pipeline with safety_checker=None
            continue

        model_cls = sub_model.__class__

        save_method_name = None
        # search for the model's base class in LOADABLE_CLASSES
        for library_name, library_classes in LOADABLE_CLASSES.items():
            library = importlib.import_module(library_name)
            for base_class, save_load_methods in library_classes.items():
                class_candidate = getattr(library, base_class, None)
                if class_candidate is not None and issubclass(model_cls, class_candidate):
                    # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                    save_method_name = save_load_methods[0]
                    break
            if save_method_name is not None:
                break

        if save_method_name is None:
            # Without this guard the lookup below would be `getattr(sub_model, None)`, which fails with
            # an opaque "attribute name must be string" TypeError; keep the type, explain the cause.
            raise TypeError(
                f"Cannot save pipeline component `{pipeline_component_name}` of type {model_cls}: it does not"
                f" derive from any of the loadable base classes {list(ALL_IMPORTABLE_CLASSES.keys())}."
            )

        save_method = getattr(sub_model, save_method_name)
        component_directory = os.path.join(save_directory, pipeline_component_name)

        # Only pass `params` to save methods that declare a parameter of that name.
        expects_params = "params" in inspect.signature(save_method).parameters
        if expects_params:
            save_method(component_directory, params=params[pipeline_component_name])
        else:
            save_method(component_directory)
proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. mirror (`str`, *optional*): Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. specify the folder name here. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the specific pipeline class. The overwritten components are then directly passed to the pipelines `__init__` method. See example below for more information. 
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Examples: ```py >>> from diffusers import FlaxDiffusionPipeline >>> # Download pipeline from huggingface.co and cache. >>> # Requires to be logged in to Hugging Face hub, >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", ... revision="bf16", ... dtype=jnp.bfloat16, ... ) >>> # Download pipeline, but use a different scheduler >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp ... ) >>> dpm_params["scheduler"] = dpmpp_state ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_pt = kwargs.pop("from_pt", False) use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) dtype = kwargs.pop("dtype", None) # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, ) # make sure we only download sub-folders and `diffusers` filenames folder_names = [k for k in config_dict.keys() if not k.startswith("_")] allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] # make sure we don't download PyTorch weights, unless when using from_pt ignore_patterns = "*.bin" if not from_pt else [] if cls != FlaxDiffusionPipeline: requested_pipeline_class = cls.__name__ else: requested_pipeline_class = config_dict.get("_class_name", cls.__name__) requested_pipeline_class = ( requested_pipeline_class if requested_pipeline_class.startswith("Flax") else "Flax" + requested_pipeline_class ) user_agent = {"pipeline_class": requested_pipeline_class} user_agent = http_user_agent(user_agent) # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, user_agent=user_agent, ) else: cached_folder = pretrained_model_name_or_path config_dict = cls.load_config(cached_folder) # 2. 
Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != FlaxDiffusionPipeline: pipeline_class = cls else: diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) class_name = ( config_dict["_class_name"] if config_dict["_class_name"].startswith("Flax") else "Flax" + config_dict["_class_name"] ) pipeline_class = getattr(diffusers_module, class_name) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} # inference_params params = {} # import it here to avoid circular import from diffusers import pipelines # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): if class_name is None: # edge case for when the pipeline was saved with safety_checker=None init_kwargs[name] = None continue is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: # 1. 
check that passed_class_obj has correct parent class if not is_pipeline_module: library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} expected_class_obj = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate if not issubclass(passed_class_obj[name].__class__, expected_class_obj): raise ValueError( f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) elif passed_class_obj[name] is None: logger.warning( f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" f" that this might lead to problems when using {pipeline_class} and is not recommended." ) sub_model_should_be_defined = False else: logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) # set passed class object loaded_sub_model = passed_class_obj[name] elif is_pipeline_module: pipeline_module = getattr(pipelines, library_name) class_obj = import_flax_or_no_model(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. 
library = importlib.import_module(library_name) class_obj = import_flax_or_no_model(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): load_method_name = importable_classes[class_name][1] load_method = getattr(class_obj, load_method_name) # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): loadable_folder = os.path.join(cached_folder, name) else: loaded_sub_model = cached_folder if issubclass(class_obj, FlaxModelMixin): loaded_sub_model, loaded_params = load_method( loadable_folder, from_pt=from_pt, use_memory_efficient_attention=use_memory_efficient_attention, dtype=dtype, ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): if from_pt: # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) loaded_params = loaded_sub_model.params del loaded_sub_model._params else: loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, FlaxSchedulerMixin): loaded_sub_model, scheduler_state = load_method(loadable_folder) params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) # 4. 
@staticmethod
def _get_signature_keys(obj):
    """
    Split the parameters of ``obj.__init__`` into required and optional names.

    Args:
        obj: A class or instance whose ``__init__`` signature is inspected.

    Returns:
        A tuple ``(expected_modules, optional_parameters)`` of two sets of parameter names:
        `expected_modules` holds every parameter without a default value (``self`` excluded),
        `optional_parameters` holds every parameter that declares a default value.
    """
    parameters = inspect.signature(obj.__init__).parameters
    # `inspect.Parameter.empty` is a sentinel and must be compared by identity: using `==` would
    # call the `__eq__` of the default value itself, which may raise or return a non-boolean
    # (e.g. array-like defaults) or wrongly report equality.
    empty = inspect.Parameter.empty
    expected_modules = {name for name, param in parameters.items() if param.default is empty} - {"self"}
    optional_parameters = {name for name, param in parameters.items() if param.default is not empty}
    return expected_modules, optional_parameters
""" expected_modules, optional_parameters = self._get_signature_keys(self) components = { k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } if set(components.keys()) != expected_modules: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" f" {expected_modules} to be defined, but {components} are defined." ) return components @staticmethod def numpy_to_pil(images): """ Convert a numpy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images # TODO: make it compatible with jax.lax def progress_bar(self, iterable): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): raise ValueError( f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) return tqdm(iterable, **self._progress_bar_config) def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs ================================================ FILE: 4DoF/diffusers/pipelines/pipeline_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import fnmatch import importlib import inspect import os import re import sys import warnings from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from huggingface_hub import hf_hub_download, model_info, snapshot_download from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm import diffusers from .. import __version__ from ..configuration_utils import ConfigMixin from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, DIFFUSERS_CACHE, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, deprecate, get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, is_compiled_module, is_safetensors_available, is_torch_version, is_transformers_available, logging, numpy_to_pil, ) if is_transformers_available(): import transformers from transformers import PreTrainedModel from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME if is_accelerate_available(): import accelerate INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" DUMMY_MODULES_FOLDER = "diffusers.utils" 
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Check whether every PyTorch ``.bin`` weight file has a matching ``.safetensors`` file.

    - By default, all models are saved with the default pytorch serialization, so the list of ``.bin`` files is
      used to know which safetensors files are needed.
    - The file list is safetensors compatible only if there is a matching safetensors file for every ``.bin``
      file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the
      ".bin" extension is replaced with ".safetensors"

    Args:
        filenames: Iterable of "/"-separated file names.
        variant: Accepted for interface compatibility; it is not used by this check.
        passed_components: Optional collection of component folder names. A file of the form
            ``<component>/<file>`` whose folder is listed here is ignored.

    Returns:
        `True` if each considered ``.bin`` file has its expected ``.safetensors`` counterpart in `filenames`
        (trivially `True` when no ``.bin`` file is considered), otherwise `False`.
    """
    passed_components = passed_components or []
    pt_filenames = []
    sf_filenames = set()

    for filename in filenames:
        parts = filename.split("/")
        if len(parts) == 2 and parts[0] in passed_components:
            continue

        _, extension = os.path.splitext(filename)
        if extension == ".bin":
            pt_filenames.append(filename)
        elif extension == ".safetensors":
            sf_filenames.add(filename)

    for pt_filename in pt_filenames:
        # 'foo/bar/baz.bin' -> folder = 'foo/bar', stem = 'baz'
        folder, _, basename = pt_filename.rpartition("/")
        stem, _ = os.path.splitext(basename)
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        # The names in `filenames` are "/"-separated (see the `split("/")` above), so the expected name
        # is rebuilt with "/" as well. `os.path.join` would insert "\\" on Windows and never match.
        expected_sf_filename = f"{folder}/{stem}.safetensors" if folder else f"{stem}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
weight_suffixs = [w.split(".")[-1] for w in weight_names] # -00001-of-00002 transformers_index_format = r"\d{5}-of-\d{5}" if variant is not None: # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors` variant_file_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" ) # `text_encoder/pytorch_model.bin.index.fp16.json` variant_index_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" ) # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` non_variant_file_re = re.compile( rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" ) # `text_encoder/pytorch_model.bin.index.json` non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") if variant is not None: variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} variant_filenames = variant_weights | variant_indexes else: variant_filenames = set() non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} non_variant_filenames = non_variant_weights | non_variant_indexes # all variant filenames will be used by default usable_filenames = set(variant_filenames) def convert_to_variant(filename): if "index" in filename: variant_filename = filename.replace("index", f"index.{variant}") elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" else: variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return 
variant_filename for f in non_variant_filenames: variant_filename = convert_to_variant(f) if variant_filename not in usable_filenames: usable_filenames.add(f) return usable_filenames, variant_filenames def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames): info = model_info( pretrained_model_name_or_path, use_auth_token=use_auth_token, revision=None, ) filenames = {sibling.rfilename for sibling in info.siblings} comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] if set(comp_model_filenames) == set(model_filenames): warnings.warn( f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", FutureWarning, ) else: warnings.warn( f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. 
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
    """
    Resolve the class named `class_name` and build a mapping of candidate parent classes.

    Args:
        library_name: Attribute name on `pipelines` when `is_pipeline_module` is true, otherwise the name of
            the module to import.
        class_name: Name of the class to fetch from the resolved module.
        importable_classes: Mapping whose keys are the candidate base-class names.
        pipelines: Object holding the pipeline modules; only read when `is_pipeline_module` is true.
        is_pipeline_module: Selects between the two lookup strategies.

    Returns:
        A tuple ``(class_obj, class_candidates)``. For a pipeline module every key of `importable_classes` maps
        to `class_obj` itself; otherwise each key maps to the attribute of that name on the imported library,
        or `None` when the library has no such attribute.
    """
    if not is_pipeline_module:
        # the class lives in a regular library, so import it from there
        module = importlib.import_module(library_name)
        resolved_cls = getattr(module, class_name)
        candidates = {base_name: getattr(module, base_name, None) for base_name in importable_classes}
        return resolved_cls, candidates

    module = getattr(pipelines, library_name)
    resolved_cls = getattr(module, class_name)
    return resolved_cls, dict.fromkeys(importable_classes, resolved_cls)
none_module = class_obj.__module__ is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( TRANSFORMERS_DUMMY_MODULES_FOLDER ) if is_dummy_path and "dummy" in none_module: # call class_obj for nice error message of missing requirements class_obj() raise ValueError( f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." ) load_method = getattr(class_obj, load_method_name) # add kwargs to loading method loading_kwargs = {} if issubclass(class_obj, torch.nn.Module): loading_kwargs["torch_dtype"] = torch_dtype if issubclass(class_obj, diffusers.OnnxRuntimeModel): loading_kwargs["provider"] = provider loading_kwargs["sess_options"] = sess_options is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) if is_transformers_available(): transformers_version = version.parse(version.parse(transformers.__version__).base_version) else: transformers_version = "N/A" is_transformers_model = ( is_transformers_available() and issubclass(class_obj, PreTrainedModel) and transformers_version >= version.parse("4.20.0") ) # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. # This makes sure that the weights won't be initialized which significantly speeds up loading. 
if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map loading_kwargs["max_memory"] = max_memory loading_kwargs["offload_folder"] = offload_folder loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["variant"] = model_variants.pop(name, None) if from_flax: loading_kwargs["from_flax"] = True # the following can be deleted once the minimum required `transformers` version # is higher than 4.27 if ( is_transformers_model and loading_kwargs["variant"] is not None and transformers_version < version.parse("4.27.0") ): raise ImportError( f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" ) elif is_transformers_model and loading_kwargs["variant"] is None: loading_kwargs.pop("variant") # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` if not (from_flax and is_transformers_model): loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage else: loading_kwargs["low_cpu_mem_usage"] = False # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) else: # else load from the root directory loaded_sub_model = load_method(cached_folder, **loading_kwargs) return loaded_sub_model class DiffusionPipeline(ConfigMixin): r""" Base class for all pipelines. [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and provides methods for loading, downloading and saving models. It also includes methods to: - move all PyTorch modules to the device of your choice - enabling/disabling the progress bar for the denoising iteration Class attributes: - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the diffusion pipeline's components. 
def register_modules(self, **kwargs):
    """
    Record each passed module in the pipeline config and attach it as an attribute.

    For every ``name=module`` keyword argument a ``(library, class_name)`` pair is stored through
    `register_to_config` and the module is set on the pipeline under `name`. A `None` module is
    registered as ``(None, None)``.
    """
    # import it here to avoid circular import
    from diffusers import pipelines

    for name, module in kwargs.items():
        if module is None:
            entry = (None, None)
        else:
            # register the config from the original module, not the dynamo compiled one
            source = module._orig_mod if is_compiled_module(module) else module
            dotted_path = source.__module__
            parts = dotted_path.split(".")

            # the second-to-last path item is the candidate pipeline folder
            parent_package = parts[-2] if len(parts) > 2 else None

            if parent_package in parts and hasattr(pipelines, parent_package):
                # the module is inside a pipeline folder, so that folder name is used as the library
                library = parent_package
            elif parts[0] in LOADABLE_CLASSES:
                library = parts[0]
            else:
                # not a library listed in LOADABLE_CLASSES: treat as a custom module, keep the full path
                library = dotted_path

            entry = (library, source.__class__.__name__)

        # save model index config
        self.register_to_config(**{name: entry})

        # set models
        setattr(self, name, module)
""" model_index_dict = dict(self.config) model_index_dict.pop("_class_name", None) model_index_dict.pop("_diffusers_version", None) model_index_dict.pop("_module", None) expected_modules, optional_kwargs = self._get_signature_keys(self) def is_saveable_module(name, value): if name not in expected_modules: return False if name in self._optional_components and value[0] is None: return False return True model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. if is_compiled_module(sub_model): sub_model = sub_model._orig_mod model_cls = sub_model.__class__ save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): if library_name in sys.modules: library = importlib.import_module(library_name) else: logger.info( f"{library_name} is not installed. 
def to(
    self,
    torch_device: Optional[Union[str, torch.device]] = None,
    torch_dtype: Optional[torch.dtype] = None,
    silence_dtype_warnings: bool = False,
):
    r"""
    Move and/or cast every `torch.nn.Module` component of the pipeline.

    Parameters:
        torch_device (`str` or `torch.device`, *optional*):
            Target device. When `None`, the components stay on their current device.
        torch_dtype (`torch.dtype`, *optional*):
            Target dtype. When `None`, the components keep their current dtype.
        silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
            Suppress the warning emitted when a `float16` component ends up on `cpu`.

    Returns:
        The pipeline itself, so calls can be chained.

    Raises:
        `ValueError`: If a component is sequentially offloaded and `torch_device` is a CUDA device.
    """
    if torch_device is None and torch_dtype is None:
        return self

    # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
    def module_is_sequentially_offloaded(module):
        if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
            return False

        return hasattr(module, "_hf_hook") and not isinstance(
            module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
        )

    def module_is_offloaded(module):
        if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
            return False

        return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

    def targets_cuda():
        # A dtype-only call passes `torch_device=None`, which `torch.device()` does not accept,
        # so the device is only parsed when one was actually requested.
        return torch_device is not None and torch.device(torch_device).type == "cuda"

    components = self.components

    # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
    pipeline_is_sequentially_offloaded = any(
        module_is_sequentially_offloaded(module) for module in components.values()
    )
    if pipeline_is_sequentially_offloaded and targets_cuda():
        raise ValueError(
            "It seems like you have activated sequential model offloading by calling"
            " `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not"
            " compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move"
            " altogether if you use sequential offloading."
        )

    # Display a warning in this case (the operation succeeds but the benefits are lost)
    pipeline_is_offloaded = any(module_is_offloaded(module) for module in components.values())
    if pipeline_is_offloaded and targets_cuda():
        logger.warning(
            "It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are"
            " now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory"
            " gains from offloading are likely to be lost. Offloading automatically takes care of moving the"
            f" individual components {', '.join(components.keys())} to GPU when needed. To make sure"
            " offloading works as expected, you should consider moving the pipeline back to CPU:"
            " `pipeline.to('cpu')` or removing the move altogether if you use offloading."
        )

    module_names, _ = self._get_signature_keys(self)
    modules = [getattr(self, n, None) for n in module_names]
    modules = [m for m in modules if isinstance(m, torch.nn.Module)]

    is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
    for module in modules:
        is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit

        if is_loaded_in_8bit:
            # 8-bit modules are left untouched: both messages below promise that the module keeps
            # its precision/device, so `.to()` must not be attempted for either request.
            if torch_dtype is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to"
                    f" {torch_dtype} is not yet supported. Module is still in 8bit precision."
                )
            if torch_device is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to"
                    f" {torch_device} via `.to()` is not yet supported. Module is still on {module.device}."
                )
        else:
            module.to(torch_device, torch_dtype)

        if (
            module.dtype == torch.float16
            and str(torch_device) in ["cpu"]
            and not silence_dtype_warnings
            and not is_offloaded
        ):
            logger.warning(
                "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
                " is not recommended to move them to `cpu` as running them will fail. Please make"
                " sure to use an accelerator to run the pipeline in inference, due to the lack of"
                " support for`float16` operations on this device in PyTorch. Please, remove the"
                " `torch_dtype=torch.float16` argument, or use another device for inference."
            )
    return self
@property
def device(self) -> torch.device:
    r"""
    Returns:
        `torch.device`: The device reported by the first `torch.nn.Module` component found among the
        pipeline's expected modules, or `cpu` when the pipeline holds no such component.
    """
    component_names, _ = self._get_signature_keys(self)
    candidates = [getattr(self, name, None) for name in component_names]
    torch_modules = [candidate for candidate in candidates if isinstance(candidate, torch.nn.Module)]

    if torch_modules:
        return torch_modules[0].device

    return torch.device("cpu")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
    r"""
    Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.

    The pipeline is set in evaluation mode (`model.eval()`) by default.

    If you get the error message below, you need to finetune the weights for your downstream task:

    ```
    Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
    - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
    ```

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                  hosted on the Hub.
                - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                  saved using
                [`~DiffusionPipeline.save_pretrained`].
        torch_dtype (`str` or `torch.dtype`, *optional*):
            Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
            dtype is automatically derived from the model's weights.
        custom_pipeline (`str`, *optional*):

            <Tip warning={true}>

            🧪 This is an experimental feature and may change in the future.

            </Tip>

            Can be either:

                - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom
                  pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines
                  the custom pipeline.
                - A string, the *file name* of a community pipeline hosted on GitHub under
                  [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                  names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                  instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                  current main branch of GitHub.
                - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                  must contain a file called `pipeline.py` that defines the custom pipeline.

            For more information on how to load and create custom pipelines, please have a look at [Loading and
            Adding Custom
            Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
            incompletely downloaded files are deleted.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        custom_revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
            `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
            custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
        mirror (`str`, *optional*):
            Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
            information.
        device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
            A map that specifies where each submodule should go. It doesn’t need to be defined for each
            parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
            same device.

            Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
            more information about each option see [designing a device
            map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
        max_memory (`Dict`, *optional*):
            A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
            each GPU and the available CPU RAM if unset.
        offload_folder (`str` or `os.PathLike`, *optional*):
            The path to offload weights if device_map contains the value `"disk"`.
        offload_state_dict (`bool`, *optional*):
            If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
            the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
            when there is some disk offload.
        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
            Speed up model loading only loading the pretrained weights and not initializing the weights. This also
            tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
            Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
            argument to `True` will raise an error.
        use_safetensors (`bool`, *optional*, defaults to `None`):
            If set to `None`, the safetensors weights are downloaded if they're available **and** if the
            safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
            weights. If set to `False`, safetensors weights are not loaded.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
            class). The overwritten components are passed directly to the pipelines `__init__` method. See example
            below for more information.
        variant (`str`, *optional*):
            Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
            loading `from_flax`.

    <Tip>

    To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
    `huggingface-cli login`.

    </Tip>

    Examples:

    ```py
    >>> from diffusers import DiffusionPipeline

    >>> # Download pipeline from huggingface.co and cache.
    >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

    >>> # Download pipeline that requires an authorization token
    >>> # For more information on access tokens, please refer to this section of the
    >>> # documentation: https://huggingface.co/docs/hub/security-tokens
    >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    >>> # Use a different scheduler
    >>> from diffusers import LMSDiscreteScheduler

    >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
    >>> pipeline.scheduler = scheduler
    ```
    """
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    resume_download = kwargs.pop("resume_download", False)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_flax = kwargs.pop("from_flax", False)
    torch_dtype = kwargs.pop("torch_dtype", None)
    custom_pipeline = kwargs.pop("custom_pipeline", None)
    custom_revision = kwargs.pop("custom_revision", None)
    provider = kwargs.pop("provider", None)
    sess_options = kwargs.pop("sess_options", None)
    device_map = kwargs.pop("device_map", None)
    max_memory = kwargs.pop("max_memory", None)
    offload_folder = kwargs.pop("offload_folder", None)
    offload_state_dict = kwargs.pop("offload_state_dict", False)
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

    # 1. Download the checkpoints and configs
    # use snapshot download here to get it working from from_pretrained
    if not os.path.isdir(pretrained_model_name_or_path):
        cached_folder = cls.download(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            resume_download=resume_download,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            from_flax=from_flax,
            use_safetensors=use_safetensors,
            custom_pipeline=custom_pipeline,
            custom_revision=custom_revision,
            variant=variant,
            **kwargs,
        )
    else:
        cached_folder = pretrained_model_name_or_path

    config_dict = cls.load_config(cached_folder)

    # pop out "_ignore_files" as it is only needed for download
    config_dict.pop("_ignore_files", None)

    # 2. Define which model components should load variants
    # We retrieve the information by matching whether variant
    # model checkpoints exist in the subfolders
    model_variants = {}
    if variant is not None:
        for folder in os.listdir(cached_folder):
            folder_path = os.path.join(cached_folder, folder)
            if not (os.path.isdir(folder_path) and folder in config_dict):
                continue

            for filename in os.listdir(folder_path):
                name_parts = filename.split(".")
                # Entries without a "." (sub-directories, extension-less files) carry no variant
                # segment; indexing `[1]` on them unconditionally would raise `IndexError`.
                if len(name_parts) > 1 and name_parts[1].startswith(variant):
                    model_variants[folder] = variant
                    break

    # 3. Load the pipeline class, if using custom module then load it from the hub
    # if we load from explicit class, let's use it
    pipeline_class = _get_pipeline_class(
        cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
    )

    # DEPRECATED: To be removed in 1.0.0
    if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
        version.parse(config_dict["_diffusers_version"]).base_version
    ) <= version.parse("0.5.1"):
        from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

        pipeline_class = StableDiffusionInpaintPipelineLegacy

        # Every segment that interpolates a class must be an f-string, otherwise the braces are
        # emitted literally in the user-facing message.
        deprecation_message = (
            "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
            f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
            " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
            " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
            f" checkpoint {pretrained_model_name_or_path} to the format of"
            " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
            f" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
        )
        deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)

    # 4. Define expected modules given pipeline signature
    # and define non-None initialized modules (=`init_kwargs`)

    # some modules can be passed directly to the init
    # in this case they are already instantiated in `kwargs`
    # extract them here
    expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
    passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
    passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

    init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

    # define init kwargs
    init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
    init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

    # remove `null` components
    def load_module(name, value):
        if value[0] is None:
            return False
        if name in passed_class_obj and passed_class_obj[name] is None:
            return False
        return True

    init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

    # Special case: safety_checker must be loaded separately when using `from_flax`
    if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
        raise NotImplementedError(
            "The safety checker cannot be automatically loaded when loading weights `from_flax`."
            " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
            " separately if you need it."
        )

    # 5. Throw nice warnings / errors for fast accelerate loading
    if len(unused_kwargs) > 0:
        logger.warning(
            f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
        )

    if low_cpu_mem_usage and not is_accelerate_available():
        low_cpu_mem_usage = False
        logger.warning(
            "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
            " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
            " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
            " install accelerate\n```\n."
        )

    if device_map is not None and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `device_map=None`."
        )

    if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
        raise NotImplementedError(
            "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
            " `low_cpu_mem_usage=False`."
        )

    if low_cpu_mem_usage is False and device_map is not None:
        raise ValueError(
            f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
            " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
        )

    # import it here to avoid circular import
    from diffusers import pipelines

    # 6. Load each module in the pipeline
    for name, (library_name, class_name) in init_dict.items():
        # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
        if class_name.startswith("Flax"):
            class_name = class_name[4:]

        # 6.2 Define all importable classes
        is_pipeline_module = hasattr(pipelines, library_name)
        importable_classes = ALL_IMPORTABLE_CLASSES
        loaded_sub_model = None

        # 6.3 Use passed sub model or load class_name from library_name
        if name in passed_class_obj:
            # if the model is in a pipeline module, then we load it from the pipeline
            # check that passed_class_obj has correct parent class
            # NOTE(review): `library` is never assigned inside this method, so it can only resolve
            # to a module-level name - confirm what `maybe_raise_or_warn` expects here.
            maybe_raise_or_warn(
                library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
            )

            loaded_sub_model = passed_class_obj[name]
        else:
            # load sub model
            loaded_sub_model = load_sub_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=importable_classes,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                torch_dtype=torch_dtype,
                provider=provider,
                sess_options=sess_options,
                device_map=device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,
                model_variants=model_variants,
                name=name,
                from_flax=from_flax,
                variant=variant,
                low_cpu_mem_usage=low_cpu_mem_usage,
                cached_folder=cached_folder,
            )

        init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

    # 7. Potentially add passed objects if expected
    missing_modules = set(expected_modules) - set(init_kwargs.keys())
    passed_modules = list(passed_class_obj.keys())
    optional_modules = pipeline_class._optional_components
    if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
        for module in missing_modules:
            init_kwargs[module] = passed_class_obj.get(module, None)
    elif len(missing_modules) > 0:
        passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
        raise ValueError(
            f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
        )

    # 8. Instantiate the pipeline
    model = pipeline_class(**init_kwargs)

    return model
@classmethod
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
    r"""
    Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.

    Parameters:
        pretrained_model_name (`str` or `os.PathLike`, *optional*):
            A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
            hosted on the Hub.
        custom_pipeline (`str`, *optional*):
            Can be either:

                - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained
                  pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines
                  the custom pipeline.

                - A string, the *file name* of a community pipeline hosted on GitHub under
                  [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                  names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                  instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                  current `main` branch of GitHub.

                - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                  must contain a file called `pipeline.py` that defines the custom pipeline.

            <Tip warning={true}>

            🧪 This is an experimental feature and may change in the future.

            </Tip>

            For more information on how to load and create custom pipelines, take a look at [How to contribute a
            community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).

        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
            incompletely downloaded files are deleted.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        custom_revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
            `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
            custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
        mirror (`str`, *optional*):
            Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
            guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
            information.
        variant (`str`, *optional*):
            Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
            loading `from_flax`.

    Returns:
        `os.PathLike`:
            A path to the downloaded pipeline.

    <Tip>

    To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
    `huggingface-cli login`.

    </Tip>

    """
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    resume_download = kwargs.pop("resume_download", False)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    from_flax = kwargs.pop("from_flax", False)
    custom_pipeline = kwargs.pop("custom_pipeline", None)
    custom_revision = kwargs.pop("custom_revision", None)
    variant = kwargs.pop("variant", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    if use_safetensors and not is_safetensors_available():
        raise ValueError(
            "`use_safetensors`=True but safetensors is not installed. Please install safetensors with"
            " `pip install safetensors`"
        )

    # `use_safetensors=None` means "prefer safetensors, but fall back to pickled weights".
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = is_safetensors_available()
        allow_pickle = True

    pipeline_is_cached = False
    allow_patterns = None
    ignore_patterns = None

    if not local_files_only:
        try:
            info = model_info(
                pretrained_model_name,
                use_auth_token=use_auth_token,
                revision=revision,
            )
        except HTTPError as e:
            logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
            local_files_only = True

    if not local_files_only:
        config_file = hf_hub_download(
            pretrained_model_name,
            cls.config_name,
            cache_dir=cache_dir,
            revision=revision,
            proxies=proxies,
            force_download=force_download,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
        )

        config_dict = cls._dict_from_json_file(config_file)

        ignore_filenames = config_dict.pop("_ignore_files", [])

        # retrieve all folder_names that contain relevant files
        folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]

        filenames = {sibling.rfilename for sibling in info.siblings}
        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

        if len(variant_filenames) == 0 and variant is not None:
            deprecation_message = (
                f"You are trying to load the model files of the `variant={variant}`, but no such modeling files"
                " are available. "
                f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from"
                f" `variant={variant}` "
                "if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as"
                " defaulting to non-variant "
                "modeling files is deprecated."
            )
            deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)

        # remove ignored filenames
        model_filenames = set(model_filenames) - set(ignore_filenames)
        variant_filenames = set(variant_filenames) - set(ignore_filenames)

        # if the whole pipeline is cached we don't have to ping the Hub
        if revision in DEPRECATED_REVISION_ARGS and version.parse(
            version.parse(__version__).base_version
        ) >= version.parse("0.20.0"):
            warn_deprecated_model_variant(
                pretrained_model_name, use_auth_token, variant, revision, model_filenames
            )

        model_folder_names = {os.path.split(f)[0] for f in model_filenames}

        # all filenames compatible with variant will be added
        allow_patterns = list(model_filenames)

        # allow all patterns from non-model folders
        # this enables downloading schedulers, tokenizers, ...
        allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
        # also allow downloading config.json files with the model
        # Patterns are matched against repository file names, which the line above already treats as
        # "/"-separated, so the separator is written explicitly instead of relying on `os.path.join`.
        allow_patterns += [f"{k}/config.json" if k else "config.json" for k in model_folder_names]

        allow_patterns += [
            SCHEDULER_CONFIG_NAME,
            CONFIG_NAME,
            cls.config_name,
            CUSTOM_PIPELINE_FILE_NAME,
        ]

        # retrieve passed components that should not be downloaded
        pipeline_class = _get_pipeline_class(
            cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision
        )
        expected_components, _ = cls._get_signature_keys(pipeline_class)
        passed_components = [k for k in expected_components if k in kwargs]

        if (
            use_safetensors
            and not allow_pickle
            and not is_safetensors_compatible(
                model_filenames, variant=variant, passed_components=passed_components
            )
        ):
            raise EnvironmentError(
                f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
            )
        if from_flax:
            ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
        elif use_safetensors and is_safetensors_compatible(
            model_filenames, variant=variant, passed_components=passed_components
        ):
            ignore_patterns = ["*.bin", "*.msgpack"]

            safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
            safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
            if (
                len(safetensors_variant_filenames) > 0
                and safetensors_model_filenames != safetensors_variant_filenames
            ):
                logger.warning(
                    f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant}"
                    f" filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant}"
                    f" filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}]"
                    "\nIf this behavior is not expected, please check your folder structure."
                )
        else:
            ignore_patterns = ["*.safetensors", "*.msgpack"]

            bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
            bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
            if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
                logger.warning(
                    f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant}"
                    f" filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant}"
                    f" filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}]"
                    "\nIf this behavior is not expected, please check your folder structure."
                )

        # Don't download any objects that are passed
        allow_patterns = [
            p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
        ]
        # Don't download index files of forbidden patterns either
        ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]

        re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
        re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

        expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
        expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]

        snapshot_folder = Path(config_file).parent
        pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)

        if pipeline_is_cached and not force_download:
            # if the pipeline is cached, we can directly return it
            # else call snapshot_download
            return snapshot_folder

    user_agent = {"pipeline_class": cls.__name__}
    if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
        user_agent["custom_pipeline"] = custom_pipeline

    # download all allow_patterns - ignore_patterns
    cached_folder = snapshot_download(
        pretrained_model_name,
        cache_dir=cache_dir,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
        user_agent=user_agent,
    )

    return cached_folder
) # Don't download any objects that are passed allow_patterns = [ p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components) ] # Don't download index files of forbidden patterns either ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns] re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns] re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns] expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)] expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)] snapshot_folder = Path(config_file).parent pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files) if pipeline_is_cached and not force_download: # if the pipeline is cached, we can directly return it # else call snapshot_download return snapshot_folder user_agent = {"pipeline_class": cls.__name__} if custom_pipeline is not None and not custom_pipeline.endswith(".py"): user_agent["custom_pipeline"] = custom_pipeline # download all allow_patterns - ignore_patterns cached_folder = snapshot_download( pretrained_model_name, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, user_agent=user_agent, ) return cached_folder @staticmethod def _get_signature_keys(obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters @property def components(self) -> Dict[str, Any]: r""" The `self.components` property can be useful to run different pipelines with the same 
@property
def components(self) -> Dict[str, Any]:
    r"""
    Collect the modules required to initialize the pipeline, which makes it possible to build other pipelines
    from the same weights and configurations without reallocating additional memory.

    Returns (`dict`):
        A dictionary mapping every required `__init__` argument name to the object stored on the pipeline.

    Raises:
        `ValueError`: If the names registered in the config do not match the required `__init__` arguments.

    Examples:

    ```py
    >>> from diffusers import (
    ...     StableDiffusionPipeline,
    ...     StableDiffusionImg2ImgPipeline,
    ...     StableDiffusionInpaintPipeline,
    ... )

    >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
    >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
    ```
    """
    expected_modules, optional_parameters = self._get_signature_keys(self)

    found = {}
    for key in self.config.keys():
        # private config entries and optional init arguments are not components
        if key.startswith("_") or key in optional_parameters:
            continue
        found[key] = getattr(self, key)

    if set(found.keys()) == expected_modules:
        return found

    raise ValueError(
        f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
        f" {expected_modules} to be defined, but {found.keys()} are defined."
    )


@staticmethod
def numpy_to_pil(images):
    """
    Turn a NumPy image, or a batch of NumPy images, into PIL images by delegating to the module-level
    `numpy_to_pil` helper.
    """
    return numpy_to_pil(images)


def progress_bar(self, iterable=None, total=None):
    """
    Build a `tqdm` progress bar configured with `self._progress_bar_config`.

    Args:
        iterable: Iterable to wrap; takes precedence over `total` when both are given.
        total: Number of expected steps for a manually updated bar.

    Raises:
        `ValueError`: If the stored config is not a `dict`, or if neither `iterable` nor `total` is given.
    """
    if hasattr(self, "_progress_bar_config"):
        if not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )
    else:
        # no config was ever set: start from tqdm's defaults
        self._progress_bar_config = {}

    if iterable is None and total is None:
        raise ValueError("Either `total` or `iterable` has to be defined.")

    if iterable is not None:
        return tqdm(iterable, **self._progress_bar_config)
    return tqdm(total=total, **self._progress_bar_config)


def set_progress_bar_config(self, **kwargs):
    """Store `kwargs` as the keyword arguments forwarded to `tqdm` by `progress_bar`."""
    self._progress_bar_config = kwargs
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed up during training is not guaranteed. ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes precedent. Parameters: attention_op (`Callable`, *optional*): Override the default `None` operator for use as `op` argument to the [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention) function of xFormers. Examples: ```py >>> import torch >>> from diffusers import DiffusionPipeline >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) >>> # Workaround for not accepting attention shape using VAE for Flash Attention >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None) ``` """ self.set_use_memory_efficient_attention_xformers(True, attention_op) def disable_xformers_memory_efficient_attention(self): r""" Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). """ self.set_use_memory_efficient_attention_xformers(False) def set_use_memory_efficient_attention_xformers( self, valid: bool, attention_op: Optional[Callable] = None ) -> None: # Recursively walk through all the children. 
# Any children which exposes the set_use_memory_efficient_attention_xformers method # gets the message def fn_recursive_set_mem_eff(module: torch.nn.Module): if hasattr(module, "set_use_memory_efficient_attention_xformers"): module.set_use_memory_efficient_attention_xformers(valid, attention_op) for child in module.children(): fn_recursive_set_mem_eff(child) module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: fn_recursive_set_mem_eff(module) def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: slice_size (`str` or `int`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ self.set_attention_slice(slice_size) def disable_attention_slicing(self): r""" Disable sliced attention computation. If `enable_attention_slicing` was previously called, attention is computed in one step. 
""" # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) def set_attention_slice(self, slice_size: Optional[int]): module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")] for module in modules: module.set_attention_slice(slice_size) ================================================ FILE: 4DoF/diffusers/pipelines/pndm/__init__.py ================================================ from .pipeline_pndm import PNDMPipeline ================================================ FILE: 4DoF/diffusers/pipelines/pndm/pipeline_pndm.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ...models import UNet2DModel from ...schedulers import PNDMScheduler from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class PNDMPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. 
scheduler ([`SchedulerMixin`]): The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. """ unet: UNet2DModel scheduler: PNDMScheduler def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): super().__init__() scheduler = PNDMScheduler.from_config(scheduler.config) self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, num_inference_steps: int = 50, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: batch_size (`int`, `optional`, defaults to 1): The number of images to generate. num_inference_steps (`int`, `optional`, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator`, `optional`): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" # For more information on the sampling method you can take a look at Algorithm 2 of # the official paper: https://arxiv.org/pdf/2202.09778.pdf # Sample gaussian noise to begin loop image = randn_tensor( (batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), generator=generator, device=self.device, ) self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): model_output = self.unet(image, t).sample image = self.scheduler.step(model_output, t, image).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/repaint/__init__.py ================================================ from .pipeline_repaint import RePaintPipeline ================================================ FILE: 4DoF/diffusers/pipelines/repaint/pipeline_repaint.py ================================================ # Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ...models import UNet2DModel
from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
    """
    Convert `image` to a `torch.Tensor` (deprecated helper; always emits a `FutureWarning`).

    A tensor is returned unchanged. A PIL image (or list of PIL images) is resized so that the width and height of
    the first image are rounded down to a multiple of 8, scaled to `[-1, 1]` and returned channels-first. A list of
    tensors is concatenated along dimension 0.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
    """
    Convert `mask` to a `torch.Tensor`.

    A tensor is returned unchanged. A PIL mask (or list of PIL masks) is converted to grayscale, resized so that the
    width and height of the first mask are rounded down to a multiple of 32, and binarized at 0.5. A list of tensors
    is concatenated along dimension 0.
    """
    if isinstance(mask, torch.Tensor):
        return mask
    elif isinstance(mask, PIL.Image.Image):
        mask = [mask]

    if isinstance(mask[0], PIL.Image.Image):
        w, h = mask[0].size
        w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
        mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
        mask = np.concatenate(mask, axis=0)
        mask = mask.astype(np.float32) / 255.0
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)
    elif isinstance(mask[0], torch.Tensor):
        mask = torch.cat(mask, dim=0)
    return mask


class RePaintPipeline(DiffusionPipeline):
    r"""
    Inpainting pipeline that pairs a `UNet2DModel` with a `RePaintScheduler`.

    Parameters:
        unet (`UNet2DModel`): Model called on the current image and timestep to produce the denoising output.
        scheduler (`RePaintScheduler`): Scheduler providing `set_timesteps`, `step` and `undo_step`.
    """

    unet: UNet2DModel
    scheduler: RePaintScheduler

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image],
        mask_image: Union[torch.Tensor, PIL.Image.Image],
        num_inference_steps: int = 250,
        eta: float = 0.0,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                The original image to inpaint on.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                The mask_image where 0.0 values define which part of the original image to inpaint (change).
            num_inference_steps (`int`, *optional*, defaults to 250):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`):
                The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is DDIM
                and 1.0 is DDPM scheduler respectively.
            jump_length (`int`, *optional*, defaults to 10):
                The number of steps taken forward in time before going backward in time for a single jump ("j" in
                RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            jump_n_sample (`int`, *optional*, defaults to 10):
                The number of times we will make forward time jump for a given chosen time sample. Take a look at
                Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.

        Raises:
            `ValueError`: If `generator` is a list whose length differs from the batch size of `image`.
        """

        original_image = _preprocess_image(image)
        original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
        mask_image = _preprocess_mask(mask_image)
        mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)

        batch_size = original_image.shape[0]

        # sample gaussian noise to begin the loop
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image_shape = original_image.shape
        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
        self.scheduler.eta = eta

        # Start one above the first timestep so the first iteration always takes the denoising branch.
        t_last = self.scheduler.timesteps[0] + 1
        # The scheduler calls below receive a single generator; when a list was passed, only its first entry is used.
        generator = generator[0] if isinstance(generator, list) else generator
        for t in self.progress_bar(self.scheduler.timesteps):
            if t < t_last:
                # predict the noise residual
                model_output = self.unet(image, t).sample
                # compute previous image: x_t -> x_t-1
                image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample

            else:
                # compute the reverse: x_t-1 -> x_t
                image = self.scheduler.undo_step(image, t_last, generator)
            t_last = t

        # Map from [-1, 1] to [0, 1] and move channels last for NumPy / PIL consumption.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
================================================ FILE: 4DoF/diffusers/pipelines/score_sde_ve/__init__.py ================================================ from .pipeline_score_sde_ve import ScoreSdeVePipeline ================================================ FILE: 4DoF/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ...models import UNet2DModel from ...schedulers import ScoreSdeVeScheduler from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class ScoreSdeVePipeline(DiffusionPipeline): r""" Parameters: This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. 
""" unet: UNet2DModel scheduler: ScoreSdeVeScheduler def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, num_inference_steps: int = 2000, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) model = self.unet sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps) for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step for _ in range(self.scheduler.config.correct_steps): model_output = self.unet(sample, sigma_t).sample sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample # prediction step model_output = model(sample, sigma_t).sample output = self.scheduler.step_pred(model_output, t, sample, generator=generator) sample, sample_mean = output.prev_sample, output.prev_sample_mean sample = sample_mean.clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": sample = self.numpy_to_pil(sample) if not return_dict: return (sample,) return ImagePipelineOutput(images=sample) ================================================ FILE: 4DoF/diffusers/pipelines/semantic_stable_diffusion/__init__.py ================================================ from dataclasses import dataclass from enum import Enum from typing import List, Optional, Union import numpy as np import PIL from PIL import Image from ...utils import BaseOutput, is_torch_available, is_transformers_available @dataclass class SemanticStableDiffusionPipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, or `None` if safety checking could not be performed. """ images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] if is_transformers_available() and is_torch_available(): from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline ================================================ FILE: 4DoF/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py ================================================ import inspect import warnings from itertools import repeat from typing import Callable, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import SemanticStableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import SemanticStableDiffusionPipeline >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> out = pipe( ... prompt="a photo of the face of a woman", ... num_images_per_prompt=1, ... guidance_scale=7, ... editing_prompt=[ ... "smiling, smile", # Concepts to apply ... "glasses, wearing glasses", ... "curls, wavy hair, curly hair", ... "beard, full beard, mustache", ... ], ... reverse_editing_direction=[ ... False, ... False, ... False, ... False, ... ], # Direction of guidance i.e. increase all concepts ... 
edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept ... edit_threshold=[ ... 0.99, ... 0.975, ... 0.925, ... 0.96, ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance ... edit_mom_beta=0.6, # Momentum beta ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other ... ) >>> image = out.images[0] ``` """ class SemanticStableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation with latent editing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) This model builds on the implementation of ['StableDiffusionPipeline'] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
safety_checker ([`Q16SafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, editing_prompt: Optional[Union[str, List[str]]] = None, editing_prompt_embeddings: Optional[torch.Tensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 10, edit_cooldown_steps: Optional[Union[int, List[int]]] = None, edit_threshold: Optional[Union[float, List[float]]] = 0.9, edit_momentum_scale: Optional[float] = 0.1, edit_mom_beta: Optional[float] = 0.4, edit_weights: Optional[List[float]] = None, sem_guidance: Optional[List[torch.Tensor]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. editing_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. editing_prompt_embeddings (`torch.Tensor>`, *optional*): Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be specified via `reverse_editing_direction`. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum will still be calculated for those steps and applied once all warmup periods are over. `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): Threshold of semantic guidance. 
edit_momentum_scale (`float`, *optional*, defaults to 0.1): Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_mom_beta (`float`, *optional*, defaults to 0.4): Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_weights (`List[float]`, *optional*, defaults to `None`): Indicates how much each individual concept should influence the overall guidance. If no weights are provided all concepts are applied equally. `edit_mom_beta` is defined as `g_i` of equation 9 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). sem_guidance (`List[torch.Tensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. Returns: [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) if editing_prompt: enable_edit_guidance = True if isinstance(editing_prompt, str): editing_prompt = [editing_prompt] enabled_editing_prompts = len(editing_prompt) elif editing_prompt_embeddings is not None: enable_edit_guidance = True enabled_editing_prompts = editing_prompt_embeddings.shape[0] else: enabled_editing_prompts = 0 enable_edit_guidance = False # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) text_input_ids = text_inputs.input_ids if text_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if enable_edit_guidance: # get safety text embeddings if editing_prompt_embeddings is None: edit_concepts_input = self.tokenizer( [x for item in editing_prompt for x in repeat(item, batch_size)], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) edit_concepts_input_ids = 
edit_concepts_input.input_ids if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode( edit_concepts_input_ids[:, self.tokenizer.model_max_length :] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] else: edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed_edit, seq_len_edit, _ = edit_concepts.shape edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if enable_edit_guidance: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, text_embeddings.dtype, self.device, generator, latents, ) # 6. Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Initialize edit_momentum to None edit_momentum = None self.uncond_estimates = None self.text_estimates = None self.edit_estimates = None self.sem_guidance = None for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] noise_pred_edit_concepts = noise_pred_out[2:] # default text guidance noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) if self.uncond_estimates is None: self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() if self.text_estimates is None: self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) self.text_estimates[i] = noise_pred_text.detach().cpu() if self.edit_estimates is None and enable_edit_guidance: self.edit_estimates = torch.zeros( (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) ) if self.sem_guidance is None: self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) if edit_momentum is None: edit_momentum = torch.zeros_like(noise_guidance) if enable_edit_guidance: concept_weights = torch.zeros( (len(noise_pred_edit_concepts), noise_guidance.shape[0]), device=self.device, dtype=noise_guidance.dtype, ) noise_guidance_edit = 
torch.zeros( (len(noise_pred_edit_concepts), *noise_guidance.shape), device=self.device, dtype=noise_guidance.dtype, ) # noise_guidance_edit = torch.zeros_like(noise_guidance) warmup_inds = [] for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): self.edit_estimates[i, c] = noise_pred_edit_concept if isinstance(edit_guidance_scale, list): edit_guidance_scale_c = edit_guidance_scale[c] else: edit_guidance_scale_c = edit_guidance_scale if isinstance(edit_threshold, list): edit_threshold_c = edit_threshold[c] else: edit_threshold_c = edit_threshold if isinstance(reverse_editing_direction, list): reverse_editing_direction_c = reverse_editing_direction[c] else: reverse_editing_direction_c = reverse_editing_direction if edit_weights: edit_weight_c = edit_weights[c] else: edit_weight_c = 1.0 if isinstance(edit_warmup_steps, list): edit_warmup_steps_c = edit_warmup_steps[c] else: edit_warmup_steps_c = edit_warmup_steps if isinstance(edit_cooldown_steps, list): edit_cooldown_steps_c = edit_cooldown_steps[c] elif edit_cooldown_steps is None: edit_cooldown_steps_c = i + 1 else: edit_cooldown_steps_c = edit_cooldown_steps if i >= edit_warmup_steps_c: warmup_inds.append(c) if i >= edit_cooldown_steps_c: noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) continue noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) if reverse_editing_direction_c: noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 concept_weights[c, :] = tmp_weights noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c # torch.quantile function expects float32 if noise_guidance_edit_tmp.dtype == torch.float32: tmp = torch.quantile( 
torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp.dtype) noise_guidance_edit_tmp = torch.where( torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], noise_guidance_edit_tmp, torch.zeros_like(noise_guidance_edit_tmp), ) noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp warmup_inds = torch.tensor(warmup_inds).to(self.device) if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: concept_weights = concept_weights.to("cpu") # Offload to cpu noise_guidance_edit = noise_guidance_edit.to("cpu") concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) concept_weights_tmp = torch.where( concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp ) concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) noise_guidance_edit_tmp = torch.index_select( noise_guidance_edit.to(self.device), 0, warmup_inds ) noise_guidance_edit_tmp = torch.einsum( "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp ) noise_guidance_edit_tmp = noise_guidance_edit_tmp noise_guidance = noise_guidance + noise_guidance_edit_tmp self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() del noise_guidance_edit_tmp del concept_weights_tmp concept_weights = concept_weights.to(self.device) noise_guidance_edit = noise_guidance_edit.to(self.device) concept_weights = torch.where( concept_weights < 0, torch.zeros_like(concept_weights), concept_weights ) concept_weights = torch.nan_to_num(concept_weights) noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) noise_guidance_edit = noise_guidance_edit + 
edit_momentum_scale * edit_momentum edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit if warmup_inds.shape[0] == len(noise_pred_edit_concepts): noise_guidance = noise_guidance + noise_guidance_edit self.sem_guidance[i] = noise_guidance_edit.detach().cpu() if sem_guidance is not None: edit_guidance = sem_guidance[i].to(self.device) noise_guidance = noise_guidance + edit_guidance noise_pred = noise_pred_uncond + noise_guidance # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/shap_e/__init__.py ================================================ from ...utils import ( OptionalDependencyNotAvailable, is_torch_available, is_transformers_available, is_transformers_version, ) try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline else: from .camera import create_pan_cameras from .pipeline_shap_e import ShapEPipeline from 
@dataclass
class DifferentiableProjectiveCamera:
    """
    Implements a batch, differentiable, standard pinhole camera

    `origin`, `x`, `y` and `z` hold one 3-vector per camera; `__post_init__` asserts that all four are
    2-D tensors with the same number of rows and exactly three columns. `shape` is the logical batch
    layout used by `camera_rays`: its first entry is the outer batch size and the product of the
    remaining entries is the number of cameras per batch element.
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float
    shape: Tuple[int, ...]

    def __post_init__(self):
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2

    def resolution(self):
        """Return `[width, height]` as a float32 tensor."""
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        """Return `[x_fov, y_fov]` as a float32 tensor."""
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def get_image_coords(self) -> torch.Tensor:
        """
        :return: coords of shape (width * height, 2)
        """
        pixel_indices = torch.arange(self.height * self.width)
        # row-major pixel order: column index first, then row index
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            dim=1,
        )
        return coords

    @property
    def camera_rays(self):
        """
        Rays through every pixel of every camera.

        :return: tensor of shape (batch_size, inner_batch_size * height * width, 2, 3); index 0 of the
            second-to-last axis is the ray origin, index 1 the unit direction.
        """
        batch_size, *inner_shape = self.shape
        inner_batch_size = int(np.prod(inner_shape))

        coords = self.get_image_coords()
        coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
        rays = self.get_camera_rays(coords)

        rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)

        return rays

    def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Turn pixel coordinates into (origin, direction) rays.

        :param coords: tensor of shape (num_cameras, ..., 2) holding (column, row) pixel coordinates; the
            leading dimension must equal the number of rows of `origin`.
        :return: tensor of shape (num_cameras, ..., 2, 3) stacking each ray's origin and unit direction.
        """
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        flat = coords.view(batch_size, -1, 2)

        res = self.resolution()
        fov = self.fov()

        # map pixel indices to [-1, 1], then scale by the tangent of the half field of view
        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        fracs = fracs.view(batch_size, -1, 2)
        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)
        rays = torch.stack(
            [
                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
            # `shape` is a required field; omitting it made every call fail with a TypeError
            shape=self.shape,
        )
class ShapEPipeline(DiffusionPipeline):
    """
    Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
            with the NeRF rendering method
    """

    def __init__(
        self,
        prior: PriorTransformer,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: HeunDiscreteScheduler,
        renderer: ShapERenderer,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            renderer=renderer,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
        when their specific submodule has its `forward` method called.

        Only the text encoder and the prior are offloaded here.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [self.text_encoder, self.prior]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the
        `prior`.

        Raises:
            ImportError: If `accelerate` is missing or older than `0.17.0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # This pipeline registers no `safety_checker`, so reading the attribute unconditionally would
        # fail; only chain one in when a subclass actually provides it.
        safety_checker = getattr(self, "safety_checker", None)
        if safety_checker is not None:
            _, hook = cpu_offload_with_hook(safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
            return self.device
        for module in self.text_encoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        """
        Encode `prompt` into the projected CLIP text embedding the prior is conditioned on.

        Args:
            prompt (`str` or `List[str]`): the prompt(s) to encode.
            device: device the token ids are moved to before running the text encoder.
            num_images_per_prompt (`int`): each prompt embedding is repeated this many times.
            do_classifier_free_guidance (`bool`): when true, an all-zero unconditional embedding is
                prepended, doubling the batch.

        Returns:
            The normalized prompt embeddings, rescaled by the square root of the embedding width.
        """
        # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
        self.tokenizer.pad_token_id = 0
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # tokenise again without truncation only to report what was cut off
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_encoder_output = self.text_encoder(text_input_ids.to(device))
        prompt_embeds = text_encoder_output.text_embeds

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # in Shap-E it normalize the prompt_embeds and then later rescale it
        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # Rescale the features to have unit variance
        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds

        return prompt_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
        output_type: Optional[str] = "pil",  # pil, np, latent
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            frame_size (`int`, *optional*, default to 64):
                the width and height of each image frame of the generated 3d output
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated frames: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or
                `"latent"` to skip rendering and return the denoised latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple. Latent output is always
                returned as a [`ShapEPipelineOutput`].

        Examples:

        Returns:
            [`ShapEPipelineOutput`] or `tuple`

        Raises:
            ValueError: If `prompt` is neither a `str` nor a `list`, or if `output_type` is not supported.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # fail fast instead of discovering an unsupported format after denoising and rendering
        if output_type not in ["latent", "np", "pil"]:
            raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)

        # prior

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_embeddings = self.prior.config.num_embeddings
        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, num_embeddings * embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.prior(
                scaled_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
            ).predicted_image_embedding

            # remove the variance
            noise_pred, _ = noise_pred.split(
                scaled_model_input.shape[2], dim=2
            )  # batch_size, num_embeddings, embedding_dim

            # Only split into unconditional/conditional halves when the batch was actually doubled;
            # the flag is a bool, so an `is not None` test would take this branch unconditionally.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                timestep=t,
                sample=latents,
            ).prev_sample

        if output_type == "latent":
            return ShapEPipelineOutput(images=latents)

        # render each latent on its own
        images = []
        for latent in latents:
            image = self.renderer.decode(
                latent[None, :],
                device,
                size=frame_size,
                ray_batch_size=4096,
                n_coarse_samples=64,
                n_fine_samples=128,
            )
            images.append(image)

        images = torch.stack(images)

        images = images.cpu().numpy()

        if output_type == "pil":
            images = [self.numpy_to_pil(image) for image in images]

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (images,)

        return ShapEPipelineOutput(images=images)
class ShapEImg2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for generating latent representation of a 3D asset from an input image and rendering it with the NeRF
    method with Shap-E.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The prior transformer that denoises the latents, conditioned on the image embedding.
        image_encoder ([`CLIPVisionModel`]):
            Image encoder whose `last_hidden_state` is used as the conditioning embedding.
        image_processor ([`CLIPImageProcessor`]):
            Processor that converts non-tensor input images into pixel values for `image_encoder`.
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with `prior` to generate the latents.
        renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
            with the NeRF rendering method
    """

    def __init__(
        self,
        prior: PriorTransformer,
        image_encoder: CLIPVisionModel,
        image_processor: CLIPImageProcessor,
        scheduler: HeunDiscreteScheduler,
        renderer: ShapERenderer,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            image_encoder=image_encoder,
            image_processor=image_processor,
            scheduler=scheduler,
            renderer=renderer,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.

        Only `image_encoder` and `prior` are offloaded by this method.

        Args:
            gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device the models execute on.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [self.image_encoder, self.prior]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.image_encoder, "_hf_hook"):
            return self.device
        for module in self.image_encoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_image(
        self,
        image,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        """
        Encode `image` into the conditioning embedding consumed by the prior.

        A list of tensors is concatenated (4-D items) or stacked (otherwise) along dim 0; any non-tensor input is
        run through `image_processor`. When classifier-free guidance is enabled, an all-zero "unconditional"
        embedding is prepended along the batch dimension.
        """
        if isinstance(image, list) and isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

        if not isinstance(image, torch.Tensor):
            # NOTE(review): only the first processed image is kept here, even for a list of PIL images.
            image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0)

        image = image.to(dtype=self.image_encoder.dtype, device=device)

        image_embeds = self.image_encoder(image)["last_hidden_state"]
        # Drop the first token along dim 1 (presumably the CLIP class token) and keep the remaining tokens.
        image_embeds = image_embeds[:, 1:, :].contiguous()

        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            negative_image_embeds = torch.zeros_like(image_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and image embeddings into a single batch
            # to avoid doing two forward passes
            image_embeds = torch.cat([negative_image_embeds, image_embeds])

        return image_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
        output_type: Optional[str] = "pil",  # pil, np, latent
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]`):
                The image or images to guide the 3D generation.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of outputs to generate per input image.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. If not provided, a latents tensor will be generated by sampling using the supplied random
                `generator`. When provided, its shape must be `(batch_size, num_embeddings * embedding_dim)`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages outputs that are closely linked to the input `image`, usually
                at the expense of lower image quality.
            frame_size (`int`, *optional*, defaults to 64):
                The width and height of each image frame of the generated 3D output.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated images. Choose between `"pil"`, `"np"` (`np.array`) or `"latent"`
                (return the denoised latents without rendering them).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple. Latent output is always
                returned wrapped in a [`ShapEPipelineOutput`].

        Examples:

        Returns:
            [`ShapEPipelineOutput`] or `tuple`

        Raises:
            ValueError: If `image` has an unsupported type, or if `output_type` is not supported.
        """

        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
            batch_size = len(image)
        else:
            raise ValueError(
                f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
            )

        # Reject unsupported output types before spending time on denoising and rendering.
        if output_type not in ("pil", "np", "latent"):
            raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)

        # prior
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_embeddings = self.prior.config.num_embeddings
        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, num_embeddings * embedding_dim),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # Latents are sampled flat and then viewed as (batch_size, num_embeddings, embedding_dim).
        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.prior(
                scaled_model_input,
                timestep=t,
                proj_embedding=image_embeds,
            ).predicted_image_embedding

            # remove the variance: keep only the first `embedding_dim` channels of the prediction
            noise_pred, _ = noise_pred.split(
                scaled_model_input.shape[2], dim=2
            )  # batch_size, num_embeddings, embedding_dim

            # Only split into unconditional/conditional halves when the batch was actually doubled above.
            # (The previous `is not None` test was always true and broke `guidance_scale <= 1`.)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                timestep=t,
                sample=latents,
            ).prev_sample

        if output_type == "latent":
            return ShapEPipelineOutput(images=latents)

        # Render every latent separately; the renderer decodes one asset at a time.
        images = [
            self.renderer.decode(
                latent[None, :],
                device,
                size=frame_size,
                ray_batch_size=4096,
                n_coarse_samples=64,
                n_fine_samples=128,
            )
            for latent in latents
        ]
        images = torch.stack(images)

        images = images.cpu().numpy()

        if output_type == "pil":
            images = [self.numpy_to_pil(image) for image in images]

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (images,)

        return ShapEPipelineOutput(images=images)
def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
    r"""
    Draw bin indices, with replacement, from a discrete distribution via inverse-CDF sampling.

    Bin `i` carries probability mass `pmf[..., i, 0]`.

    Args:
        pmf: [batch_size, *shape, support_size, 1] where (pmf.sum(dim=-2) == 1).all()
        n_samples: number of indices to draw per distribution

    Return:
        integer indices of shape [batch_size, *shape, n_samples, 1], each in `[0, support_size - 1]`
    """
    *leading_dims, support_size, trailing = pmf.shape
    assert trailing == 1

    # One row per distribution, turned into its running total.
    flat_pmf = pmf.view(-1, support_size)
    cdf = flat_pmf.cumsum(dim=1)

    # Uniform draws are located inside the CDF to pick the bins.
    uniform_draws = torch.rand(cdf.shape[0], n_samples, device=cdf.device)
    flat_indices = torch.searchsorted(cdf, uniform_draws)

    # Clamp guards against an index one past the end (e.g. when the CDF tops out slightly below the draw).
    indices = flat_indices.view(*leading_dims, n_samples, 1)
    return indices.clamp(0, support_size - 1)
def integrate_samples(volume_range, ts, density, channels):
    r"""
    Integrate per-sample densities and channels along each ray.

    Args:
        volume_range: object whose `partition(ts)` returns `(lower, upper, delta)`; only `delta` is used
        ts: sample positions along the ray, passed to `volume_range.partition`
        density: torch.Tensor [batch_size, *shape, n_samples, 1]
        channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]

    returns:
        channels: channels summed over the sample axis, weighted by `weights`
        weights: torch.Tensor [batch_size, *shape, n_samples, 1], i.e. `(1 - exp(-density * dt))` times the
            transmittance accumulated before each sample
        transmittance: `exp(-sum(density * dt))` over all samples of the ray
    """
    # 1. Per-interval optical depth and its running sum along the sample axis.
    _, _, interval_widths = volume_range.partition(ts)
    optical_depth = density * interval_widths
    accumulated = torch.cumsum(optical_depth, dim=-2)

    # Fraction of light surviving the whole range.
    transmittance = torch.exp(-accumulated[..., -1, :])

    # Fraction absorbed inside each interval on its own.
    absorbed = 1.0 - torch.exp(-optical_depth)

    # Fraction of light that reaches each interval: nothing is in front of the first one.
    nothing_in_front = torch.zeros_like(accumulated[..., :1, :])
    depth_in_front = torch.cat([nothing_in_front, -accumulated[..., :-1, :]], dim=-2)
    reaching = torch.exp(depth_in_front)

    # Probability of light hitting and reflecting off something at each sample.
    weights = absorbed * reaching

    # 2. Weighted sum of the channels over the samples.
    integrated = torch.sum(channels * weights, dim=-2)

    return integrated, weights, transmittance
class BoundingBoxVolume(nn.Module):
    """
    Axis-aligned box, described by two opposite corners, that rays can be intersected with.
    """

    def __init__(
        self,
        *,
        bbox_min,
        bbox_max,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
    ):
        """
        Args:
            bbox_min: the left/bottommost corner of the bounding box (3 values)
            bbox_max: the opposite corner of the bounding box (3 values)
            min_dist: lower clamp applied to the entry distance `t0` of every ray; must be >= 0
            min_t_range: a ray only counts as intersecting when `t0 + min_t_range < t1`; must be > 0
        """
        super().__init__()

        self.min_dist, self.min_t_range = min_dist, min_t_range

        lower_corner = torch.tensor(bbox_min)
        upper_corner = torch.tensor(bbox_max)
        self.bbox_min = lower_corner
        self.bbox_max = upper_corner
        self.bbox = torch.stack([lower_corner, upper_corner])

        assert self.bbox.shape == (2, 3)
        assert min_dist >= 0.0
        assert min_t_range > 0.0

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        epsilon=1e-6,
    ):
        """
        Intersect rays `origin + t * direction` with the box.

        Args:
            origin: [batch_size, *shape, 3]
            direction: [batch_size, *shape, 3]
            t0_lower: Optional [batch_size, *shape, 1] lower bound applied to t0
            epsilon: offset added to (or subtracted from, for negative components) the direction before
                dividing, so that zero components do not divide by zero

        Return:
            A `VolumeRange` holding `t0`, `t1` and `intersected`, each of shape [batch_size, *shape, 1]. For rays
            flagged as intersected, `t0 + min_t_range < t1`; for all other rays `t0` is set to 0 and `t1` to 1.
        """
        batch_size, *shape, _ = origin.shape
        broadcast_dims = [1] * len(shape)
        corners = self.bbox.view(1, *broadcast_dims, 2, 3).to(origin.device)

        # Push every direction component away from zero, keeping its sign, before dividing.
        ray_dir = direction[..., None, :]
        safe_dir = torch.where(ray_dir < 0, ray_dir - epsilon, ray_dir + epsilon)

        # Ray parameter at which each of the two corner planes is crossed, per axis.
        ts = (corners - origin[..., None, :]) / safe_dir

        # Cases to think about:
        #
        #   1. t1 <= t0: the ray does not pass through the AABB.
        #   2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
        #   3. t0 <= 0 <= t1: the ray starts from inside the BB
        #   4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
        #
        # 1 and 4 are handled by the `t0 < t1` style test below; clamping t0 to min_dist (>= 0) covers 2 and 3.
        entry_per_axis = ts.min(dim=-2).values
        exit_per_axis = ts.max(dim=-2).values
        t0 = entry_per_axis.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
        t1 = exit_per_axis.min(dim=-1, keepdim=True).values
        assert t0.shape == t1.shape == (batch_size, *shape, 1)

        if t0_lower is not None:
            assert t0.shape == t0_lower.shape
            t0 = torch.maximum(t0, t0_lower)

        hit = t0 + self.min_t_range < t1
        # Rays that miss get the placeholder range [0, 1].
        t0 = torch.where(hit, t0, torch.zeros_like(t0))
        t1 = torch.where(hit, t1, torch.ones_like(t1))

        return VolumeRange(t0=t0, t1=t1, intersected=hit)
class ImportanceRaySampler(nn.Module):
    """
    Given the initial estimate of densities, this samples more from regions/bins expected to have objects.
    """

    def __init__(
        self,
        volume_range: VolumeRange,
        ts: torch.Tensor,
        weights: torch.Tensor,
        blur_pool: bool = False,
        alpha: float = 1e-5,
    ):
        """
        Args:
            volume_range: the range in which a ray intersects the given volume.
            ts: earlier samples from the coarse rendering step
            weights: discretized version of density * transmittance
            blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
            alpha: small value to add to weights.
        """
        # nn.Module must be initialised before attributes are set; without this, any Module API
        # (`parameters()`, `.to()`, `state_dict()`, ...) raises AttributeError on the instance.
        super().__init__()
        self.volume_range = volume_range
        # Detached copies: the sampler must not alias or back-propagate into the coarse pass.
        self.ts = ts.clone().detach()
        self.weights = weights.clone().detach()
        self.blur_pool = blur_pool
        self.alpha = alpha

    @torch.no_grad()
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        Draw `n_samples` new ts per ray, concentrated in the coarse bins with the largest weights.

        Args:
            t0: start time has shape [batch_size, *shape, 1]. Not read by this sampler; the bins come from the
                `volume_range` and `ts` given at construction.
            t1: finish time has shape [batch_size, *shape, 1]. Not read by this sampler either.
            n_samples: number of ts to sample
        Return:
            sampled ts of shape [batch_size, *shape, n_samples, 1], sorted along the sample axis
        """
        lower, upper, _ = self.volume_range.partition(self.ts)

        batch_size, *shape, n_coarse_samples, _ = self.ts.shape

        weights = self.weights
        if self.blur_pool:
            # Pad by repeating the edge weights, take pairwise maxima, then average neighbouring maxima.
            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
        # alpha keeps every bin at non-zero mass (and avoids dividing by zero for all-zero weights).
        weights = weights + self.alpha
        pmf = weights / weights.sum(dim=-2, keepdim=True)
        inds = sample_pmf(pmf, n_samples)
        assert inds.shape == (batch_size, *shape, n_samples, 1)
        assert (inds >= 0).all() and (inds < n_coarse_samples).all()

        # Place each sample uniformly at random inside its chosen bin.
        t_rand = torch.rand(inds.shape, device=inds.device)
        lower_ = torch.gather(lower, -2, inds)
        upper_ = torch.gather(upper, -2, inds)

        ts = lower_ + (upper_ - lower_) * t_rand
        ts = torch.sort(ts, dim=-2).values
        return ts
class ChannelsProj(nn.Module):
    """
    Per-vector linear projection of a latent to `channels` values, followed by layer normalisation.

    The weight and bias of a single `nn.Linear(d_latent, vectors * channels)` are viewed as one
    `(channels, d_latent)` matrix and one `(channels,)` bias per vector; the bias is added after the
    normalisation rather than before it.
    """

    def __init__(
        self,
        *,
        vectors: int,
        channels: int,
        d_latent: int,
    ):
        """
        Args:
            vectors: number of latent vectors (size of dim 1 of the input)
            channels: number of output values per vector
            d_latent: size of each input latent vector
        """
        super().__init__()
        self.vectors, self.channels, self.d_latent = vectors, channels, d_latent
        self.proj = nn.Linear(d_latent, vectors * channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch_size, vectors, d_latent]
        Return:
            [batch_size, vectors, channels]
        """
        per_vector_weight = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        per_vector_bias = self.proj.bias.view(1, self.vectors, self.channels)

        # Each vector v is multiplied by its own (channels, d_latent) matrix.
        projected = torch.einsum("bvd,vcd->bvc", x, per_vector_weight)
        return self.norm(projected) + per_vector_bias
class ShapERenderer(ModelMixin, ConfigMixin):
    """
    Renders a Shap-E latent into images.

    The latent is projected to MLP weights (`params_proj`), those weights are copied into the NeRF MLP (`mlp`),
    and rays from a set of pan cameras are volume-rendered inside a `[-1, 1]^3` bounding box over a constant
    background (`void`).
    """

    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
        background: Tuple[float] = (
            255.0,
            255.0,
            255.0,
        ),
    ):
        """
        Args:
            param_names: names of the MLP parameters produced from the latent
            param_shapes: shape of each projected parameter, aligned with `param_names`
            d_latent: size of each latent vector fed to the projection
            d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at: forwarded to `MLPNeRSTFModel`
            background: background colour, divided by 255 inside `VoidNeRFModel`
        """
        super().__init__()

        self.params_proj = ShapEParamsProjModel(
            param_names=param_names,
            param_shapes=param_shapes,
            d_latent=d_latent,
        )
        self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
        self.void = VoidNeRFModel(background=background, channel_scale=255.0)
        self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])

    @torch.no_grad()
    def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
        """
        Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written
        below with some abuse of notations)

            C(r) := sum(
                transmittance(t[i]) * integrate(
                    lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
                ) for i in range(len(parts))
            ) + transmittance(t[-1]) * void_model(t[-1]).channels

        where

        1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing
           through the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely)
        2) density and channels are obtained by evaluating the appropriate part.model at time t.
        3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects (parts[i].volume \\
           union(part.volume for part in parts[:i])) at the surface of the shell (if bounded). If the ray does not
           intersect, the integral over this segment is evaluated as 0 and transmittance(t[i + 1]) :=
           transmittance(t[i]).
        4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model
           (i.e. we consider this space to be empty).

        Args:
            rays: [batch_size x ... x 2 x 3] origin and direction.
            sampler: object whose `sample(t0, t1, n_samples)` returns the ts to evaluate along each ray.
            n_samples: number of ts to sample.
            prev_model_out: model output of the previous (coarse) rendering pass. When given, its ts are merged
                with the new samples and the "fine" heads of the MLP are used.
            render_with_direction: whether the ray direction is passed to the MLP.

        Return:
            A tuple of
            - `channels`: the rendered channels per ray
            - an `ImportanceRaySampler` for additional fine-grained rendering
            - the raw model output
        """
        origin, direction = rays[..., 0, :], rays[..., 1, :]

        # Integrate over [t[i], t[i + 1]]

        # 1. Intersect the rays with the current volume and sample ts to integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=None)
        ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
        ts = ts.to(rays.dtype)

        if prev_model_out is not None:
            # Append the previous ts now before fprop because previous
            # rendering used a different model and we can't reuse the output.
            ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values

        batch_size, *_shape, _t0_dim = vrange.t0.shape
        _, *ts_shape, _ts_dim = ts.shape

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        directions = directions.to(self.mlp.dtype)
        positions = positions.to(self.mlp.dtype)

        optional_directions = directions if render_with_direction else None

        model_out = self.mlp(
            position=positions,
            direction=optional_directions,
            ts=ts,
            nerf_level="coarse" if prev_model_out is None else "fine",
        )

        # 3. Integrate the model results
        channels, weights, transmittance = integrate_samples(
            vrange, model_out.ts, model_out.density, model_out.channels
        )

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
        channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))
        # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we
        # consider this space to be empty).
        channels = channels + transmittance * self.void(origin)

        weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)

        return channels, weighted_sampler, model_out

    @torch.no_grad()
    def decode(
        self,
        latents,
        device,
        size: int = 64,
        ray_batch_size: int = 4096,
        n_coarse_samples=64,
        n_fine_samples=128,
    ):
        """
        Render the asset described by `latents` from the pan cameras.

        Args:
            latents: latent of a single asset, with a leading batch dimension of 1
            device: device the camera rays are moved to before rendering
            size: width and height of each rendered frame
            ray_batch_size: maximum number of rays rendered per MLP pass
            n_coarse_samples: samples per ray in the coarse (stratified) pass
            n_fine_samples: additional samples per ray in the fine (importance-weighted) pass

        Return:
            Rendered frames, viewed as `(*camera.shape, height, width, channels)` with a leading dimension of
            size 1 squeezed away.
        """
        # project the parameters from the generated latents
        projected_params = self.params_proj(latents)

        # update the mlp layers of the renderer
        for name, param in self.mlp.state_dict().items():
            if f"nerstf.{name}" in projected_params.keys():
                param.copy_(projected_params[f"nerstf.{name}"].squeeze(0))

        # create cameras object
        camera = create_pan_cameras(size)
        rays = camera.camera_rays
        rays = rays.to(device)

        coarse_sampler = StratifiedRaySampler()

        images = []

        # Step through all rays, including a final partial batch. Floor-dividing the ray count used to drop
        # the remainder whenever it was not a multiple of `ray_batch_size`, which broke the reshape below.
        n_rays = rays.shape[1]
        for start in range(0, n_rays, ray_batch_size):
            rays_batch = rays[:, start : start + ray_batch_size]

            # render rays with coarse, stratified samples.
            _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
            # Then, render with additional importance-weighted ray samples.
            channels, _, _ = self.render_rays(
                rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
            )

            images.append(channels)

        images = torch.cat(images, dim=1)
        images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)

        return images
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn from transformers.modeling_utils import ModuleUtilsMixin from transformers.models.t5.modeling_t5 import ( T5Block, T5Config, T5LayerNorm, ) from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): @register_to_config def __init__( self, input_dims: int, targets_context_length: int, d_model: int, dropout_rate: float, num_layers: int, num_heads: int, d_kv: int, d_ff: int, feed_forward_proj: str, is_decoder: bool = False, ): super().__init__() self.input_proj = nn.Linear(input_dims, d_model, bias=False) self.position_encoding = nn.Embedding(targets_context_length, d_model) self.position_encoding.weight.requires_grad = False self.dropout_pre = nn.Dropout(p=dropout_rate) t5config = T5Config( d_model=d_model, num_heads=num_heads, d_kv=d_kv, d_ff=d_ff, feed_forward_proj=feed_forward_proj, dropout_rate=dropout_rate, is_decoder=is_decoder, is_encoder_decoder=False, ) self.encoders = nn.ModuleList() for lyr_num in range(num_layers): lyr = T5Block(t5config) self.encoders.append(lyr) self.layer_norm = T5LayerNorm(d_model) self.dropout_post = nn.Dropout(p=dropout_rate) def forward(self, encoder_inputs, encoder_inputs_mask): x = self.input_proj(encoder_inputs) # terminal relative positional encodings max_positions = encoder_inputs.shape[1] input_positions = torch.arange(max_positions, device=encoder_inputs.device) seq_lens = encoder_inputs_mask.sum(-1) input_positions = torch.roll(input_positions.unsqueeze(0), 
tuple(seq_lens.tolist()), dims=0) x += self.position_encoding(input_positions) x = self.dropout_pre(x) # inverted the attention mask input_shape = encoder_inputs.size() extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) for lyr in self.encoders: x = lyr(x, extended_attention_mask)[0] x = self.layer_norm(x) return self.dropout_post(x), encoder_inputs_mask ================================================ FILE: 4DoF/diffusers/pipelines/spectrogram_diffusion/midi_utils.py ================================================ # Copyright 2022 The Music Spectrogram Diffusion Authors. # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import dataclasses import math import os from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn.functional as F from ...utils import is_note_seq_available from .pipeline_spectrogram_diffusion import TARGET_FEATURE_LENGTH if is_note_seq_available(): import note_seq else: raise ImportError("Please install note-seq via `pip install note-seq`") INPUT_FEATURE_LENGTH = 2048 SAMPLE_RATE = 16000 HOP_SIZE = 320 FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE) DEFAULT_STEPS_PER_SECOND = 100 DEFAULT_MAX_SHIFT_SECONDS = 10 DEFAULT_NUM_VELOCITY_BINS = 1 SLAKH_CLASS_PROGRAMS = { "Acoustic Piano": 0, "Electric Piano": 4, "Chromatic Percussion": 8, "Organ": 16, "Acoustic Guitar": 24, "Clean Electric Guitar": 26, "Distorted Electric Guitar": 29, "Acoustic Bass": 32, "Electric Bass": 33, "Violin": 40, "Viola": 41, "Cello": 42, "Contrabass": 43, "Orchestral Harp": 46, "Timpani": 47, "String Ensemble": 48, "Synth Strings": 50, "Choir and Voice": 52, "Orchestral Hit": 55, "Trumpet": 56, "Trombone": 57, "Tuba": 58, "French Horn": 60, "Brass Section": 61, "Soprano/Alto Sax": 64, "Tenor Sax": 66, "Baritone Sax": 67, "Oboe": 68, "English Horn": 69, "Bassoon": 70, "Clarinet": 71, "Pipe": 73, "Synth Lead": 80, "Synth Pad": 88, } @dataclasses.dataclass class NoteRepresentationConfig: """Configuration note representations.""" onsets_only: bool include_ties: bool @dataclasses.dataclass class NoteEventData: pitch: int velocity: Optional[int] = None program: Optional[int] = None is_drum: Optional[bool] = None instrument: Optional[int] = None @dataclasses.dataclass class NoteEncodingState: """Encoding state for note transcription, keeping track of active pitches.""" # velocity bin for active pitches and programs active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict) @dataclasses.dataclass class EventRange: type: str min_value: int max_value: int @dataclasses.dataclass class 
Event: type: str value: int class Tokenizer: def __init__(self, regular_ids: int): # The special tokens: 0=PAD, 1=EOS, and 2=UNK self._num_special_tokens = 3 self._num_regular_tokens = regular_ids def encode(self, token_ids): encoded = [] for token_id in token_ids: if not 0 <= token_id < self._num_regular_tokens: raise ValueError( f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})" ) encoded.append(token_id + self._num_special_tokens) # Add EOS token encoded.append(1) # Pad to till INPUT_FEATURE_LENGTH encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded)) return encoded class Codec: """Encode and decode events. Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not include things like EOS or UNK token handling. To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required and specified separately. """ def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]): """Define Codec. Args: max_shift_steps: Maximum number of shift steps that can be encoded. steps_per_second: Shift steps will be interpreted as having a duration of 1 / steps_per_second. event_ranges: Other supported event types and their ranges. """ self.steps_per_second = steps_per_second self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps) self._event_ranges = [self._shift_range] + event_ranges # Ensure all event types have unique names. assert len(self._event_ranges) == len({er.type for er in self._event_ranges}) @property def num_classes(self) -> int: return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) # The next couple methods are simplified special case methods just for shift # events that are intended to be used from within autograph functions. 
def is_shift_event_index(self, index: int) -> bool: return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value) @property def max_shift_steps(self) -> int: return self._shift_range.max_value def encode_event(self, event: Event) -> int: """Encode an event to an index.""" offset = 0 for er in self._event_ranges: if event.type == er.type: if not er.min_value <= event.value <= er.max_value: raise ValueError( f"Event value {event.value} is not within valid range " f"[{er.min_value}, {er.max_value}] for type {event.type}" ) return offset + event.value - er.min_value offset += er.max_value - er.min_value + 1 raise ValueError(f"Unknown event type: {event.type}") def event_type_range(self, event_type: str) -> Tuple[int, int]: """Return [min_id, max_id] for an event type.""" offset = 0 for er in self._event_ranges: if event_type == er.type: return offset, offset + (er.max_value - er.min_value) offset += er.max_value - er.min_value + 1 raise ValueError(f"Unknown event type: {event_type}") def decode_event_index(self, index: int) -> Event: """Decode an event index to an Event.""" offset = 0 for er in self._event_ranges: if offset <= index <= offset + er.max_value - er.min_value: return Event(type=er.type, value=er.min_value + index - offset) offset += er.max_value - er.min_value + 1 raise ValueError(f"Unknown event index: {index}") @dataclasses.dataclass class ProgramGranularity: # both tokens_map_fn and program_map_fn should be idempotent tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]] program_map_fn: Callable[[int], int] def drop_programs(tokens, codec: Codec): """Drops program change events from a token sequence.""" min_program_id, max_program_id = codec.event_type_range("program") return tokens[(tokens < min_program_id) | (tokens > max_program_id)] def programs_to_midi_classes(tokens, codec): """Modifies program events to be the first program in the MIDI class.""" min_program_id, max_program_id = codec.event_type_range("program") 
is_program = (tokens >= min_program_id) & (tokens <= max_program_id) return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens) PROGRAM_GRANULARITIES = { # "flat" granularity; drop program change tokens and set NoteSequence # programs to zero "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0), # map each program to the first program in its MIDI class "midi_class": ProgramGranularity( tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8) ), # leave programs as is "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program), } def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1): """ equivalent of tf.signal.frame """ signal_length = signal.shape[axis] if pad_end: frames_overlap = frame_length - frame_step rest_samples = np.abs(signal_length - frames_overlap) % np.abs(frame_length - frames_overlap) pad_size = int(frame_length - rest_samples) if pad_size != 0: pad_axis = [0] * signal.ndim pad_axis[axis] = pad_size signal = F.pad(signal, pad_axis, "constant", pad_value) frames = signal.unfold(axis, frame_length, frame_step) return frames def program_to_slakh_program(program): # this is done very hackily, probably should use a custom mapping for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True): if program >= slakh_program: return slakh_program def audio_to_frames( samples, hop_size: int, frame_rate: int, ) -> Tuple[Sequence[Sequence[int]], torch.Tensor]: """Convert audio samples to non-overlapping frames and frame times.""" frame_size = hop_size samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant") # Split audio into frames. 
frames = frame( torch.Tensor(samples).unsqueeze(0), frame_length=frame_size, frame_step=frame_size, pad_end=False, # TODO check why its off by 1 here when True ) num_frames = len(samples) // frame_size times = np.arange(num_frames) / frame_rate return frames, times def note_sequence_to_onsets_and_offsets_and_programs( ns: note_seq.NoteSequence, ) -> Tuple[Sequence[float], Sequence[NoteEventData]]: """Extract onset & offset times and pitches & programs from a NoteSequence. The onset & offset times will not necessarily be in sorted order. Args: ns: NoteSequence from which to extract onsets and offsets. Returns: times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for note offsets. """ # Sort by program and pitch and put offsets before onsets as a tiebreaker for # subsequent stable sort. notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch)) times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes] values = [ NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False) for note in notes if not note.is_drum ] + [ NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum) for note in notes ] return times, values def num_velocity_bins_from_codec(codec: Codec): """Get number of velocity bins from event codec.""" lo, hi = codec.event_type_range("velocity") return hi - lo # segment an array into segments of length n def segment(a, n): return [a[i : i + n] for i in range(0, len(a), n)] def velocity_to_bin(velocity, num_velocity_bins): if velocity == 0: return 0 else: return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY) def note_event_data_to_events( state: Optional[NoteEncodingState], value: NoteEventData, codec: Codec, ) -> Sequence[Event]: """Convert note event data to a sequence of events.""" if value.velocity is None: # onsets only, no program or velocity return 
[Event("pitch", value.pitch)] else: num_velocity_bins = num_velocity_bins_from_codec(codec) velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins) if value.program is None: # onsets + offsets + velocities only, no programs if state is not None: state.active_pitches[(value.pitch, 0)] = velocity_bin return [Event("velocity", velocity_bin), Event("pitch", value.pitch)] else: if value.is_drum: # drum events use a separate vocabulary return [Event("velocity", velocity_bin), Event("drum", value.pitch)] else: # program + velocity + pitch if state is not None: state.active_pitches[(value.pitch, value.program)] = velocity_bin return [ Event("program", value.program), Event("velocity", velocity_bin), Event("pitch", value.pitch), ] def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]: """Output program and pitch events for active notes plus a final tie event.""" events = [] for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]): if state.active_pitches[(pitch, program)]: events += [Event("program", program), Event("pitch", pitch)] events.append(Event("tie", 0)) return events def encode_and_index_events( state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None ): """Encode a sequence of timed events and index to audio frame times. Encodes time shifts as repeated single step shifts for later run length encoding. Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio frame. This can be used e.g. to prepend events representing the current state to a targets segment. Args: state: Initial event encoding state. event_times: Sequence of event times. event_values: Sequence of event values. encode_event_fn: Function that transforms event value into a sequence of one or more Event objects. codec: An Codec object that maps Event objects to indices. frame_times: Time for every audio frame. 
encoding_state_to_events_fn: Function that transforms encoding state into a sequence of one or more Event objects. Returns: events: Encoded events and shifts. event_start_indices: Corresponding start event index for every audio frame. Note: one event can correspond to multiple audio indices due to sampling rate differences. This makes splitting sequences tricky because the same event can appear at the end of one sequence and the beginning of another. event_end_indices: Corresponding end event index for every audio frame. Used to ensure when slicing that one chunk ends where the next begins. Should always be true that event_end_indices[i] = event_start_indices[i + 1]. state_events: Encoded "state" events representing the encoding state before each event. state_event_indices: Corresponding state event index for every audio frame. """ indices = np.argsort(event_times, kind="stable") event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices] event_values = [event_values[i] for i in indices] events = [] state_events = [] event_start_indices = [] state_event_indices = [] cur_step = 0 cur_event_idx = 0 cur_state_event_idx = 0 def fill_event_start_indices_to_cur_step(): while ( len(event_start_indices) < len(frame_times) and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second ): event_start_indices.append(cur_event_idx) state_event_indices.append(cur_state_event_idx) for event_step, event_value in zip(event_steps, event_values): while event_step > cur_step: events.append(codec.encode_event(Event(type="shift", value=1))) cur_step += 1 fill_event_start_indices_to_cur_step() cur_event_idx = len(events) cur_state_event_idx = len(state_events) if encoding_state_to_events_fn: # Dump state to state events *before* processing the next event, because # we want to capture the state prior to the occurrence of the event. 
for e in encoding_state_to_events_fn(state): state_events.append(codec.encode_event(e)) for e in encode_event_fn(state, event_value, codec): events.append(codec.encode_event(e)) # After the last event, continue filling out the event_start_indices array. # The inequality is not strict because if our current step lines up exactly # with (the start of) an audio frame, we need to add an additional shift event # to "cover" that frame. while cur_step / codec.steps_per_second <= frame_times[-1]: events.append(codec.encode_event(Event(type="shift", value=1))) cur_step += 1 fill_event_start_indices_to_cur_step() cur_event_idx = len(events) # Now fill in event_end_indices. We need this extra array to make sure that # when we slice events, each slice ends exactly where the subsequent slice # begins. event_end_indices = event_start_indices[1:] + [len(events)] events = np.array(events).astype(np.int32) state_events = np.array(state_events).astype(np.int32) event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH) event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH) state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH) outputs = [] for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices): outputs.append( { "inputs": events, "event_start_indices": start_indices, "event_end_indices": end_indices, "state_events": state_events, "state_event_indices": event_indices, } ) return outputs def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"): """Extract target sequence corresponding to audio token segment.""" features = features.copy() start_idx = features["event_start_indices"][0] end_idx = features["event_end_indices"][-1] features[feature_key] = features[feature_key][start_idx:end_idx] if state_events_end_token is not None: # Extract the state events 
corresponding to the audio start token, and # prepend them to the targets array. state_event_start_idx = features["state_event_indices"][0] state_event_end_idx = state_event_start_idx + 1 while features["state_events"][state_event_end_idx - 1] != state_events_end_token: state_event_end_idx += 1 features[feature_key] = np.concatenate( [ features["state_events"][state_event_start_idx:state_event_end_idx], features[feature_key], ], axis=0, ) return features def map_midi_programs( feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs" ) -> Mapping[str, Any]: """Apply MIDI program map to token sequences.""" granularity = PROGRAM_GRANULARITIES[granularity_type] feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec) return feature def run_length_encode_shifts_fn( features, codec: Codec, feature_key: str = "inputs", state_change_event_types: Sequence[str] = (), ) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: """Return a function that run-length encodes shifts for a given codec. Args: codec: The Codec to use for shift events. feature_key: The feature key for which to run-length encode shifts. state_change_event_types: A list of event types that represent state changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones will be removed. Returns: A preprocessing function that run-length encodes single-step shifts. """ state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types] def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]: """Combine leading/interior shifts, trim trailing shifts. Args: features: Dict of features to process. Returns: A dict of features. 
""" events = features[feature_key] shift_steps = 0 total_shift_steps = 0 output = np.array([], dtype=np.int32) current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32) for event in events: if codec.is_shift_event_index(event): shift_steps += 1 total_shift_steps += 1 else: # If this event is a state change and has the same value as the current # state, we can skip it entirely. is_redundant = False for i, (min_index, max_index) in enumerate(state_change_event_ranges): if (min_index <= event) and (event <= max_index): if current_state[i] == event: is_redundant = True current_state[i] = event if is_redundant: continue # Once we've reached a non-shift event, RLE all previous shift events # before outputting the non-shift event. if shift_steps > 0: shift_steps = total_shift_steps while shift_steps > 0: output_steps = np.minimum(codec.max_shift_steps, shift_steps) output = np.concatenate([output, [output_steps]], axis=0) shift_steps -= output_steps output = np.concatenate([output, [event]], axis=0) features[feature_key] = output return features return run_length_encode_shifts(features) def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig): tie_token = codec.encode_event(Event("tie", 0)) state_events_end_token = tie_token if note_representation_config.include_ties else None features = extract_sequence_with_indices( features, state_events_end_token=state_events_end_token, feature_key="inputs" ) features = map_midi_programs(features, codec) features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"]) return features class MidiProcessor: def __init__(self): self.codec = Codec( max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND, steps_per_second=DEFAULT_STEPS_PER_SECOND, event_ranges=[ EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS), EventRange("tie", 0, 0), 
EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM), EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH), ], ) self.tokenizer = Tokenizer(self.codec.num_classes) self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True) def __call__(self, midi: Union[bytes, os.PathLike, str]): if not isinstance(midi, bytes): with open(midi, "rb") as f: midi = f.read() ns = note_seq.midi_to_note_sequence(midi) ns_sus = note_seq.apply_sustain_control_changes(ns) for note in ns_sus.notes: if not note.is_drum: note.program = program_to_slakh_program(note.program) samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE)) _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE) times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus) events = encode_and_index_events( state=NoteEncodingState(), event_times=times, event_values=values, frame_times=frame_times, codec=self.codec, encode_event_fn=note_event_data_to_events, encoding_state_to_events_fn=note_encoding_state_to_events, ) events = [ note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events ] input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events] return input_tokens ================================================ FILE: 4DoF/diffusers/pipelines/spectrogram_diffusion/notes_encoder.py ================================================ # Copyright 2022 The Music Spectrogram Diffusion Authors. # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.nn as nn from transformers.modeling_utils import ModuleUtilsMixin from transformers.models.t5.modeling_t5 import T5Block, T5Config, T5LayerNorm from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): @register_to_config def __init__( self, max_length: int, vocab_size: int, d_model: int, dropout_rate: float, num_layers: int, num_heads: int, d_kv: int, d_ff: int, feed_forward_proj: str, is_decoder: bool = False, ): super().__init__() self.token_embedder = nn.Embedding(vocab_size, d_model) self.position_encoding = nn.Embedding(max_length, d_model) self.position_encoding.weight.requires_grad = False self.dropout_pre = nn.Dropout(p=dropout_rate) t5config = T5Config( vocab_size=vocab_size, d_model=d_model, num_heads=num_heads, d_kv=d_kv, d_ff=d_ff, dropout_rate=dropout_rate, feed_forward_proj=feed_forward_proj, is_decoder=is_decoder, is_encoder_decoder=False, ) self.encoders = nn.ModuleList() for lyr_num in range(num_layers): lyr = T5Block(t5config) self.encoders.append(lyr) self.layer_norm = T5LayerNorm(d_model) self.dropout_post = nn.Dropout(p=dropout_rate) def forward(self, encoder_input_tokens, encoder_inputs_mask): x = self.token_embedder(encoder_input_tokens) seq_length = encoder_input_tokens.shape[1] inputs_positions = torch.arange(seq_length, device=encoder_input_tokens.device) x += self.position_encoding(inputs_positions) x = self.dropout_pre(x) # inverted the attention mask input_shape = encoder_input_tokens.size() 
extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, input_shape) for lyr in self.encoders: x = lyr(x, extended_attention_mask)[0] x = self.layer_norm(x) return self.dropout_post(x), encoder_inputs_mask ================================================ FILE: 4DoF/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py ================================================ # Copyright 2022 The Music Spectrogram Diffusion Authors. # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import math from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np import torch from ...models import T5FilmDecoder from ...schedulers import DDPMScheduler from ...utils import is_onnx_available, logging, randn_tensor if is_onnx_available(): from ..onnx_utils import OnnxRuntimeModel from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from .continous_encoder import SpectrogramContEncoder from .notes_encoder import SpectrogramNotesEncoder logger = logging.get_logger(__name__) # pylint: disable=invalid-name TARGET_FEATURE_LENGTH = 256 class SpectrogramDiffusionPipeline(DiffusionPipeline): _optional_components = ["melgan"] def __init__( self, notes_encoder: SpectrogramNotesEncoder, continuous_encoder: SpectrogramContEncoder, decoder: T5FilmDecoder, scheduler: DDPMScheduler, melgan: OnnxRuntimeModel if is_onnx_available() else Any, ) -> None: super().__init__() # From MELGAN self.min_value = math.log(1e-5) # Matches MelGAN training. self.max_value = 4.0 # Largest value for most examples self.n_dims = 128 self.register_modules( notes_encoder=notes_encoder, continuous_encoder=continuous_encoder, decoder=decoder, scheduler=scheduler, melgan=melgan, ) def scale_features(self, features, output_range=(-1.0, 1.0), clip=False): """Linearly scale features to network outputs range.""" min_out, max_out = output_range if clip: features = torch.clip(features, self.min_value, self.max_value) # Scale to [0, 1]. zero_one = (features - self.min_value) / (self.max_value - self.min_value) # Scale to [min_out, max_out]. return zero_one * (max_out - min_out) + min_out def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False): """Invert by linearly scaling network outputs to features range.""" min_out, max_out = input_range outputs = torch.clip(outputs, min_out, max_out) if clip else outputs # Scale to [0, 1]. zero_one = (outputs - min_out) / (max_out - min_out) # Scale to [self.min_value, self.max_value]. 
return zero_one * (self.max_value - self.min_value) + self.min_value def encode(self, input_tokens, continuous_inputs, continuous_mask): tokens_mask = input_tokens > 0 tokens_encoded, tokens_mask = self.notes_encoder( encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask ) continuous_encoded, continuous_mask = self.continuous_encoder( encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask ) return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)] def decode(self, encodings_and_masks, input_tokens, noise_time): timesteps = noise_time if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(input_tokens.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device) logits = self.decoder( encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps ) return logits @torch.no_grad() def __call__( self, input_tokens: List[List[int]], generator: Optional[torch.Generator] = None, num_inference_steps: int = 100, return_dict: bool = True, output_type: str = "numpy", callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ) -> Union[AudioPipelineOutput, Tuple]: if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32) full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32) ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) for i, encoder_input_tokens in enumerate(input_tokens): if i == 0: encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to( device=self.device, dtype=self.decoder.dtype ) # The first chunk has no previous context. encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device) else: # The full song pipeline does not feed in a context feature, so the mask # will be all 0s after the feature converter. Because we know we're # feeding in a full context chunk from the previous prediction, set it # to all 1s. encoder_continuous_mask = ones encoder_continuous_inputs = self.scale_features( encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True ) encodings_and_masks = self.encode( input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device), continuous_inputs=encoder_continuous_inputs, continuous_mask=encoder_continuous_mask, ) # Sample encoder_continuous_inputs shaped gaussian noise to begin loop x = randn_tensor( shape=encoder_continuous_inputs.shape, generator=generator, device=self.device, dtype=self.decoder.dtype, ) # set step values self.scheduler.set_timesteps(num_inference_steps) # Denoising diffusion loop for j, t in enumerate(self.progress_bar(self.scheduler.timesteps)): output = self.decode( encodings_and_masks=encodings_and_masks, input_tokens=x, noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) ) # Compute previous output: x_t -> x_t-1 x = self.scheduler.step(output, t, x, generator=generator).prev_sample mel = self.scale_to_features(x, input_range=[-1.0, 1.0]) encoder_continuous_inputs = mel[:1] pred_mel = mel.cpu().float().numpy() full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1) # call the callback, if provided if callback is not 
None and i % callback_steps == 0: callback(i, full_pred_mel) logger.info("Generated segment", i) if output_type == "numpy" and not is_onnx_available(): raise ValueError( "Cannot return output in 'np' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'." ) elif output_type == "numpy" and self.melgan is None: raise ValueError( "Cannot return output in 'np' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'." ) if output_type == "numpy": output = self.melgan(input_features=full_pred_mel.astype(np.float32)) else: output = full_pred_mel if not return_dict: return (output,) return AudioPipelineOutput(audios=output) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/__init__.py ================================================ from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL from PIL import Image from ...utils import ( BaseOutput, OptionalDependencyNotAvailable, is_flax_available, is_k_diffusion_available, is_k_diffusion_version, is_onnx_available, is_torch_available, is_transformers_available, is_transformers_version, ) @dataclass class StableDiffusionPipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, or `None` if safety checking could not be performed. 
""" images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .pipeline_stable_unclip import StableUnCLIPPipeline from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline from .safety_checker import StableDiffusionSafetyChecker from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline else: 
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) else: from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline try: if not ( is_torch_available() and is_transformers_available() and is_k_diffusion_available() and is_k_diffusion_version(">=", "0.0.12") ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline try: if not (is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_onnx_objects import * # noqa F403 else: from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline if is_transformers_available() and is_flax_available(): import flax @flax.struct.dataclass class FlaxStableDiffusionPipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. 
Args: images (`np.ndarray`) Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content. """ images: np.ndarray nsfw_content_detected: List[bool] from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" Conversion script for the Stable Diffusion checkpoints.""" import re from io import BytesIO from typing import Optional import requests import torch from transformers import ( AutoFeatureExtractor, BertTokenizerFast, CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionConfig, CLIPVisionModelWithProjection, ) from ...models import ( AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel, ) from ...schedulers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UnCLIPScheduler, ) from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging from ...utils.import_utils import BACKENDS_MAPPING from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from ..paint_by_example import PaintByExampleImageEncoder from ..pipeline_utils import DiffusionPipeline from .safety_checker import StableDiffusionSafetyChecker from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer if is_accelerate_available(): from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device logger = logging.get_logger(__name__) # pylint: disable=invalid-name def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. 
""" if n_shave_prefix_segments >= 0: return ".".join(path.split(".")[n_shave_prefix_segments:]) else: return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item.replace("in_layers.0", "norm1") new_item = new_item.replace("in_layers.2", "conv1") new_item = new_item.replace("out_layers.0", "norm2") new_item = new_item.replace("out_layers.3", "conv2") new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside resnets to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local renaming) """ mapping = [] for old_item in old_list: new_item = old_item # new_item = new_item.replace('norm.weight', 'group_norm.weight') # new_item = new_item.replace('norm.bias', 'group_norm.bias') # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): """ Updates paths inside attentions to the new naming scheme (local 
renaming) """ mapping = [] for old_item in old_list: new_item = old_item new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("q.weight", "to_q.weight") new_item = new_item.replace("q.bias", "to_q.bias") new_item = new_item.replace("k.weight", "to_k.weight") new_item = new_item.replace("k.bias", "to_k.bias") new_item = new_item.replace("v.weight", "to_v.weight") new_item = new_item.replace("v.bias", "to_v.bias") new_item = new_item.replace("proj_out.weight", "to_out.0.weight") new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) mapping.append({"old": old_item, "new": new_item}) return mapping def assign_to_checkpoint( paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): """ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new checkpoint. """ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. 
if attention_paths_to_split is not None: for path, path_map in attention_paths_to_split.items(): old_tensor = old_checkpoint[path] channels = old_tensor.shape[0] // 3 target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) checkpoint[path_map["key"]] = key.reshape(target_shape) checkpoint[path_map["value"]] = value.reshape(target_shape) for path in paths: new_path = path["new"] # These have already been assigned if attention_paths_to_split is not None and new_path in attention_paths_to_split: continue # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") if additional_replacements is not None: for replacement in additional_replacements: new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) shape = old_checkpoint[path["old"]].shape if is_attn_weight and len(shape) == 3: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] elif is_attn_weight and len(shape) == 4: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): keys = list(checkpoint.keys()) attn_keys = ["query.weight", "key.weight", "value.weight"] for key in keys: if ".".join(key.split(".")[-2:]) in attn_keys: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0, 0] elif "proj_attn.weight" in key: if checkpoint[key].ndim > 2: 
checkpoint[key] = checkpoint[key][:, :, 0] def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): """ Creates a config for the diffusers based on the config of the LDM model. """ if controlnet: unet_params = original_config.model.params.control_stage_config.params else: if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: unet_params = original_config.model.params.unet_config.params else: unet_params = original_config.model.params.network_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 if unet_params.transformer_depth is not None: transformer_layers_per_block = ( unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth) ) else: transformer_layers_per_block = 1 vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: head_dim_mult = unet_params.model_channels // unet_params.num_head_channels head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] class_embed_type = None addition_embed_type = None addition_time_embed_dim 
= None projection_class_embeddings_input_dim = None context_dim = None if unet_params.context_dim is not None: context_dim = ( unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] ) if "num_classes" in unet_params: if unet_params.num_classes == "sequential": if context_dim in [2048, 1280]: # SDXL addition_embed_type = "text_time" addition_time_embed_dim = 256 else: class_embed_type = "projection" assert "adm_in_channels" in unet_params projection_class_embeddings_input_dim = unet_params.adm_in_channels else: raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") config = { "sample_size": image_size // vae_scale_factor, "in_channels": unet_params.in_channels, "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), "layers_per_block": unet_params.num_res_blocks, "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, "class_embed_type": class_embed_type, "addition_embed_type": addition_embed_type, "addition_time_embed_dim": addition_time_embed_dim, "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, "transformer_layers_per_block": transformer_layers_per_block, } if controlnet: config["conditioning_channels"] = unet_params.hint_channels else: config["out_channels"] = unet_params.out_channels config["up_block_types"] = tuple(up_block_types) return config def create_vae_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. 
""" vae_params = original_config.model.params.first_stage_config.params.ddconfig _ = original_config.model.params.first_stage_config.params.embed_dim block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, "in_channels": vae_params.in_channels, "out_channels": vae_params.out_ch, "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), "latent_channels": vae_params.z_channels, "layers_per_block": vae_params.num_res_blocks, } return config def create_diffusers_schedular(original_config): schedular = DDIMScheduler( num_train_timesteps=original_config.model.params.timesteps, beta_start=original_config.model.params.linear_start, beta_end=original_config.model.params.linear_end, beta_schedule="scaled_linear", ) return schedular def create_ldm_bert_config(original_config): bert_params = original_config.model.parms.cond_stage_config.params config = LDMBertConfig( d_model=bert_params.n_embed, encoder_layers=bert_params.n_layer, encoder_ffn_dim=bert_params.n_embed * 4, ) return config def convert_ldm_unet_checkpoint( checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False ): """ Takes a state dict and a config, and returns a converted checkpoint. """ if skip_extract_state_dict: unet_state_dict = checkpoint else: # extract state_dict for UNet unet_state_dict = {} keys = list(checkpoint.keys()) if controlnet: unet_key = "control_model." else: unet_key = "model.diffusion_model." 
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") logger.warning( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port ... 
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") if config["addition_embed_type"] == "text_time": new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) 
} # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) resnet_1_paths = renew_resnet_paths(resnet_1) assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( 
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) layer_in_block_id = i % (config["layers_per_block"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ] new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] # Clear attentions as they have been attributed above. 
if len(attentions) == 2: attentions = [] if len(attentions): paths = renew_attention_paths(attentions) meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] if controlnet: # conditioning embedding orig_index = 0 new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) orig_index += 2 diffusers_index = 0 while diffusers_index < 6: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) diffusers_index += 1 orig_index += 2 new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) # down blocks for i in range(num_input_blocks): new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") # mid block new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = 
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the VAE ("first stage") weights of an LDM-style checkpoint to the diffusers key layout.

    Args:
        checkpoint: Mapping from original parameter names to tensors. Only entries whose name starts with
            `first_stage_model.` are read; the mapping itself is not modified.
        config: VAE config, passed through unchanged to `assign_to_checkpoint`.

    Returns:
        A new `dict` holding the VAE weights under diffusers-style parameter names.

    Raises:
        KeyError: if one of the fixed convolution / norm parameters is missing from the checkpoint.
    """
    # extract state dict for VAE
    vae_key = "first_stage_model."
    # Slice the leading prefix off. `str.replace` would also delete any later occurrence of the
    # prefix inside the key, which silently produces a wrong parameter name.
    vae_state_dict = {
        key[len(vae_key) :]: checkpoint[key] for key in list(checkpoint.keys()) if key.startswith(vae_key)
    }

    new_checkpoint = {}

    # Parameters that are copied one-to-one: (diffusers module name, original module name).
    direct_copies = (
        ("encoder.conv_in", "encoder.conv_in"),
        ("encoder.conv_out", "encoder.conv_out"),
        ("encoder.conv_norm_out", "encoder.norm_out"),
        ("decoder.conv_in", "decoder.conv_in"),
        ("decoder.conv_out", "decoder.conv_out"),
        ("decoder.conv_norm_out", "decoder.norm_out"),
        ("quant_conv", "quant_conv"),
        ("post_quant_conv", "post_quant_conv"),
    )
    for new_module, old_module in direct_copies:
        for param in ("weight", "bias"):
            new_checkpoint[f"{new_module}.{param}"] = vae_state_dict[f"{old_module}.{param}"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    num_mid_res_blocks = 2

    def _convert_mid_block(side):
        # Shared by encoder and decoder: two resnets plus one attention layer, then flatten the
        # attention convolutions that were written into `new_checkpoint`.
        mid_resnets = [key for key in vae_state_dict if f"{side}.mid.block" in key]
        for i in range(1, num_mid_res_blocks + 1):
            resnets = [key for key in mid_resnets if f"{side}.mid.block_{i}" in key]

            paths = renew_vae_resnet_paths(resnets)
            meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
            assign_to_checkpoint(
                paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
            )

        mid_attentions = [key for key in vae_state_dict if f"{side}.mid.attn" in key]
        paths = renew_vae_attention_paths(mid_attentions)
        meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
        conv_attn_to_linear(new_checkpoint)

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    _convert_mid_block("encoder")

    for i in range(num_up_blocks):
        # The original decoder numbers its up blocks in reverse order.
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    _convert_mid_block("decoder")

    return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
    """
    Load the CLIP text-encoder weights stored in an LDM-style checkpoint into a `CLIPTextModel`.

    Args:
        checkpoint: Mapping from original parameter names to tensors. Entries whose name starts with
            `cond_stage_model.transformer` or `conditioner.embedders.0.transformer` are used.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Forwarded to `CLIPTextConfig.from_pretrained` when a fresh model has to be created.
        text_encoder (*optional*):
            An existing text model to fill with the checkpoint weights. If `None`, an empty
            `openai/clip-vit-large-patch14` `CLIPTextModel` is created.

    Returns:
        The text model (either the freshly created one or `text_encoder`) with the checkpoint weights assigned.
    """
    if text_encoder is None:
        config_name = "openai/clip-vit-large-patch14"
        # Honour `local_files_only`; previously the flag was accepted but never used.
        config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)

        with init_empty_weights():
            text_model = CLIPTextModel(config)
    else:
        # Previously `text_model` stayed unbound on this path and the function crashed
        # with UnboundLocalError as soon as a text encoder was supplied.
        text_model = text_encoder

    keys = list(checkpoint.keys())

    text_model_dict = {}

    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

    for key in keys:
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                # Drop the prefix and the "." that follows it.
                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

    for param_name, param in text_model_dict.items():
        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

    return text_model
def convert_paint_by_example_checkpoint(checkpoint):
    """
    Build a `PaintByExampleImageEncoder` and fill it with the weights found in `checkpoint`.

    The CLIP vision tower, the mapper blocks, the final layer norm, the output projection and the
    unconditional vector are loaded in that order.

    Args:
        checkpoint: Mapping from original parameter names to tensors.

    Returns:
        The populated `PaintByExampleImageEncoder`.
    """
    vision_config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
    model = PaintByExampleImageEncoder(vision_config)

    # load clip vision
    vision_state = {
        key[len("cond_stage_model.transformer.") :]: checkpoint[key]
        for key in list(checkpoint.keys())
        if key.startswith("cond_stage_model.transformer")
    }
    model.model.load_state_dict(vision_state)

    # load mapper
    mapper_entries = {}
    for original_name, tensor in checkpoint.items():
        if original_name.startswith("cond_stage_model.mapper"):
            mapper_entries[original_name[len("cond_stage_model.mapper.res") :]] = tensor

    # original sub-module name -> diffusers sub-module name(s); a fused tensor is split evenly
    # along dim 0 across all listed targets
    split_targets = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapper_state = {}
    for key, tensor in mapper_entries.items():
        block_prefix = key[: len("blocks.i")]
        remainder = key.split(block_prefix)[-1]
        suffix = remainder.split(".")[-1]
        name = remainder.split(suffix)[0][1:-1]

        targets = split_targets[name]
        chunk = tensor.shape[0] // len(targets)
        for position, target in enumerate(targets):
            mapper_state[".".join([block_prefix, target, suffix])] = tensor[position * chunk : (position + 1) * chunk]

    model.mapper.load_state_dict(mapper_state)

    # load final layer norm
    final_ln_state = {
        "bias": checkpoint["cond_stage_model.final_ln.bias"],
        "weight": checkpoint["cond_stage_model.final_ln.weight"],
    }
    model.final_layer_norm.load_state_dict(final_ln_state)

    # load final proj
    proj_out_state = {
        "bias": checkpoint["proj_out.bias"],
        "weight": checkpoint["proj_out.weight"],
    }
    model.proj_out.load_state_dict(proj_out_state)

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model
def stable_unclip_image_encoder(original_config):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    The embedder class is read from `original_config.model.params.embedder_config.target` (last dotted
    segment). Two classes are handled: `ClipImageEmbedder` (only with model `ViT-L/14`) and
    `FrozenOpenCLIPImageEmbedder`.

    Args:
        original_config: The original model config; accessed by attribute.

    Returns:
        A `(feature_extractor, image_encoder)` tuple.

    Raises:
        NotImplementedError: for an unknown embedder class or an unknown CLIP model name.
    """
    image_embedder_config = original_config.model.params.embedder_config
    sd_clip_image_embedder_class = image_embedder_config.target.split(".")[-1]

    # Resolve the hub repository first so that the two supported variants share one loading path.
    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model
        if clip_model_name != "ViT-L/14":
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
        encoder_repo = "openai/clip-vit-large-patch14"
    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        encoder_repo = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    feature_extractor = CLIPImageProcessor()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(encoder_repo)

    return feature_extractor, image_encoder
""" noise_aug_config = original_config.model.params.noise_aug_config noise_aug_class = noise_aug_config.target noise_aug_class = noise_aug_class.split(".")[-1] if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": noise_aug_config = noise_aug_config.params embedding_dim = noise_aug_config.timestep_dim max_noise_level = noise_aug_config.noise_schedule_config.timesteps beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) if "clip_stats_path" in noise_aug_config: if clip_stats_path is None: raise ValueError("This stable unclip config requires a `clip_stats_path`") clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) clip_mean = clip_mean[None, :] clip_std = clip_std[None, :] clip_stats_state_dict = { "mean": clip_mean, "std": clip_std, } image_normalizer.load_state_dict(clip_stats_state_dict) else: raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") return image_normalizer, image_noising_scheduler def convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema, use_linear_projection=None, cross_attention_dim=None, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") if use_linear_projection is not None: ctrlnet_config["use_linear_projection"] = use_linear_projection if cross_attention_dim is not None: ctrlnet_config["cross_attention_dim"] = cross_attention_dim controlnet_model = ControlNetModel(**ctrlnet_config) # Some controlnet ckpt files are distributed independently from the rest of the # model components i.e. 
https://huggingface.co/thibaud/controlnet-sd21/ if "time_embed.0.weight" in checkpoint: skip_extract_state_dict = True else: skip_extract_state_dict = False converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True, skip_extract_state_dict=skip_extract_state_dict, ) controlnet_model.load_state_dict(converted_ctrl_checkpoint) return controlnet_model def download_from_original_stable_diffusion_ckpt( checkpoint_path: str, original_config_file: str = None, image_size: Optional[int] = None, prediction_type: str = None, model_type: str = None, extract_ema: bool = False, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, device: str = None, from_safetensors: bool = False, stable_unclip: Optional[str] = None, stable_unclip_prior: Optional[str] = None, clip_stats_path: Optional[str] = None, controlnet: Optional[bool] = None, load_safety_checker: bool = True, pipeline_class: DiffusionPipeline = None, local_files_only=False, vae_path=None, text_encoder=None, tokenizer=None, ) -> DiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` config file. Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is recommended that you override the default values and/or supply an `original_config_file` wherever possible. Args: checkpoint_path (`str`): Path to `.ckpt` file. original_config_file (`str`): Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models. image_size (`int`, *optional*, defaults to 512): The image size that the model was trained on. 
Use 512 for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use 768 for Stable Diffusion v2. prediction_type (`str`, *optional*): The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. num_in_channels (`int`, *optional*, defaults to None): The number of input channels. If `None`, it will be automatically inferred. scheduler_type (`str`, *optional*, defaults to 'pndm'): Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", "ddim"]`. model_type (`str`, *optional*, defaults to `None`): The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. is_img2img (`bool`, *optional*, defaults to `False`): Whether the model should be loaded as an img2img pipeline. extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning. upcast_attention (`bool`, *optional*, defaults to `None`): Whether the attention computation should always be upcasted. This is necessary when running stable diffusion 2.1. device (`str`, *optional*, defaults to `None`): The device to use. Pass `None` to determine automatically. from_safetensors (`str`, *optional*, defaults to `False`): If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. load_safety_checker (`bool`, *optional*, defaults to `True`): Whether to load the safety checker or not. Defaults to `True`. pipeline_class (`str`, *optional*, defaults to `None`): The pipeline class to use. Pass `None` to determine automatically. 
local_files_only (`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). text_encoder (`CLIPTextModel`, *optional*, defaults to `None`): An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`): An instance of [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if needed. return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ # import pipelines here to avoid circular import error when using from_single_file method from diffusers import ( LDMTextToImagePipeline, PaintByExamplePipeline, StableDiffusionControlNetPipeline, StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) if pipeline_class is None: pipeline_class = StableDiffusionPipeline if prediction_type == "v-prediction": prediction_type = "v_prediction" if not is_omegaconf_available(): raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) from omegaconf import OmegaConf if from_safetensors: if not is_safetensors_available(): raise ValueError(BACKENDS_MAPPING["safetensors"][1]) from safetensors.torch import load_file as safe_load checkpoint = safe_load(checkpoint_path, device="cpu") else: if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(checkpoint_path, map_location=device) else: checkpoint = torch.load(checkpoint_path, map_location=device) # Sometimes models 
don't have the global_step item if "global_step" in checkpoint: global_step = checkpoint["global_step"] else: logger.debug("global_step key not found in model") global_step = None # NOTE: this while loop isn't great but this controlnet checkpoint has one additional # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] if original_config_file is None: key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias" key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" # model_type = "v1" config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: # model_type = "v2" config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" if global_step == 110000: # v2.1 needs to upcast attention upcast_attention = True elif key_name_sd_xl_base in checkpoint: # only base xl has two text embedders config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml" elif key_name_sd_xl_refiner in checkpoint: # only refiner xl has embedder and one text embedders config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml" original_config_file = BytesIO(requests.get(config_url).content) original_config = OmegaConf.load(original_config_file) # Convert the text model. 
if ( model_type is None and "cond_stage_config" in original_config.model.params and original_config.model.params.cond_stage_config is not None ): model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") elif model_type is None and original_config.model.params.network_config is not None: if original_config.model.params.network_config.params.context_dim == 2048: model_type = "SDXL" else: model_type = "SDXL-Refiner" if image_size is None: image_size = 1024 if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: num_in_channels = 9 elif num_in_channels is None: num_in_channels = 4 if "unet_config" in original_config.model.params: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if ( "parameterization" in original_config["model"]["params"] and original_config["model"]["params"]["parameterization"] == "v" ): if prediction_type is None: # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here prediction_type = "epsilon" if global_step == 875000 else "v_prediction" if image_size is None: # NOTE: For stable diffusion 2 base one has to pass `image_size==512` # as it relies on a brittle global step parameter here image_size = 512 if global_step == 875000 else 768 else: if prediction_type is None: prediction_type = "epsilon" if image_size is None: image_size = 512 if controlnet is None: controlnet = "control_stage_config" in original_config.model.params if controlnet: controlnet_model = convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ) num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 if model_type in ["SDXL", "SDXL-Refiner"]: scheduler_dict = { "beta_schedule": "scaled_linear", "beta_start": 0.00085, "beta_end": 
0.012, "interpolation_type": "linear", "num_train_timesteps": num_train_timesteps, "prediction_type": "epsilon", "sample_max_value": 1.0, "set_alpha_to_one": False, "skip_prk_steps": True, "steps_offset": 1, "timestep_spacing": "leading", } scheduler = EulerDiscreteScheduler.from_config(scheduler_dict) scheduler_type = "euler" else: beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 scheduler = DDIMScheduler( beta_end=beta_end, beta_schedule="scaled_linear", beta_start=beta_start, num_train_timesteps=num_train_timesteps, steps_offset=1, clip_sample=False, set_alpha_to_one=False, prediction_type=prediction_type, ) # make sure scheduler works correctly with DDIM scheduler.register_to_config(clip_sample=False) if scheduler_type == "pndm": config = dict(scheduler.config) config["skip_prk_steps"] = True scheduler = PNDMScheduler.from_config(config) elif scheduler_type == "lms": scheduler = LMSDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "heun": scheduler = HeunDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "euler": scheduler = EulerDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "euler-ancestral": scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config) elif scheduler_type == "dpm": scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) elif scheduler_type == "ddim": scheduler = scheduler else: raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") # Convert the UNet2DConditionModel model. 
unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config["upcast_attention"] = upcast_attention with init_empty_weights(): unet = UNet2DConditionModel(**unet_config) converted_unet_checkpoint = convert_ldm_unet_checkpoint( checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema ) for param_name, param in converted_unet_checkpoint.items(): set_module_tensor_to_device(unet, param_name, "cpu", value=param) # Convert the VAE model. if vae_path is None: vae_config = create_vae_diffusers_config(original_config, image_size=image_size) converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) if ( "model" in original_config and "params" in original_config.model and "scale_factor" in original_config.model.params ): vae_scaling_factor = original_config.model.params.scale_factor else: vae_scaling_factor = 0.18215 # default SD scaling factor vae_config["scaling_factor"] = vae_scaling_factor with init_empty_weights(): vae = AutoencoderKL(**vae_config) for param_name, param in converted_vae_checkpoint.items(): set_module_tensor_to_device(vae, param_name, "cpu", value=param) else: vae = AutoencoderKL.from_pretrained(vae_path) if model_type == "FrozenOpenCLIPEmbedder": config_name = "stabilityai/stable-diffusion-2" config_kwargs = {"subfolder": "text_encoder"} text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") if stable_unclip is None: if controlnet: pipe = StableDiffusionControlNetPipeline( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, controlnet=controlnet_model, safety_checker=None, feature_extractor=None, requires_safety_checker=False, ) else: pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, ) else: 
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( original_config, clip_stats_path=clip_stats_path, device=device ) if stable_unclip == "img2img": feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) pipe = StableUnCLIPImg2ImgPipeline( # image encoding components feature_extractor=feature_extractor, image_encoder=image_encoder, # image noising components image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, # regular denoising components tokenizer=tokenizer, text_encoder=text_model, unet=unet, scheduler=scheduler, # vae vae=vae, ) elif stable_unclip == "txt2img": if stable_unclip_prior is None or stable_unclip_prior == "karlo": karlo_model = "kakaobrain/karlo-v1-alpha" prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) else: raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}") pipe = StableUnCLIPPipeline( # prior components prior_tokenizer=prior_tokenizer, prior_text_encoder=prior_text_model, prior=prior, prior_scheduler=prior_scheduler, # image noising components image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, # regular denoising components tokenizer=tokenizer, text_encoder=text_model, unet=unet, scheduler=scheduler, # vae vae=vae, ) else: raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}") elif model_type == "PaintByExample": vision_model = convert_paint_by_example_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") feature_extractor = 
AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") pipe = PaintByExamplePipeline( vae=vae, image_encoder=vision_model, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=feature_extractor, ) elif model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint( checkpoint, local_files_only=local_files_only, text_encoder=text_encoder ) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer if load_safety_checker: safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") else: safety_checker = None feature_extractor = None if controlnet: pipe = StableDiffusionControlNetPipeline( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, controlnet=controlnet_model, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) else: pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) elif model_type in ["SDXL", "SDXL-Refiner"]: if model_type == "SDXL": tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!") config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" config_kwargs = {"projection_dim": 1280} text_encoder_2 = convert_open_clip_checkpoint( checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs ) pipe = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=text_encoder_2, tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, 
def download_controlnet_from_original_ckpt(
    checkpoint_path: str,
    original_config_file: str,
    image_size: int = 512,
    extract_ema: bool = False,
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: Optional[str] = None,
    from_safetensors: bool = False,
    use_linear_projection: Optional[bool] = None,
    cross_attention_dim: Optional[int] = None,
) -> DiffusionPipeline:
    """
    Load an original ControlNet checkpoint and convert it with `convert_controlnet_checkpoint`.

    Args:
        checkpoint_path (`str`):
            Path to the checkpoint file (a safetensors file when `from_safetensors` is set, otherwise a file
            readable by `torch.load`).
        original_config_file (`str`):
            Path to the original YAML config; it is loaded with `OmegaConf.load` and must contain
            `model.params.control_stage_config`.
        image_size (`int`, *optional*, defaults to 512):
            Forwarded to `convert_controlnet_checkpoint`.
        extract_ema (`bool`, *optional*, defaults to `False`):
            Forwarded to `convert_controlnet_checkpoint`.
        num_in_channels (`int`, *optional*):
            When given, overrides `model.params.unet_config.params.in_channels` in the loaded config.
        upcast_attention (`bool`, *optional*):
            Forwarded to `convert_controlnet_checkpoint`.
        device (`str`, *optional*):
            `map_location` used by `torch.load`. Defaults to `"cuda"` when CUDA is available, else `"cpu"`.
            Ignored for safetensors checkpoints, which are always read on the CPU.
        from_safetensors (`bool`, *optional*, defaults to `False`):
            Read the checkpoint with `safetensors.safe_open` instead of `torch.load`.
        use_linear_projection (`bool`, *optional*):
            Forwarded to `convert_controlnet_checkpoint`.
        cross_attention_dim (`int`, *optional*):
            Forwarded to `convert_controlnet_checkpoint`.

    Returns:
        The object returned by `convert_controlnet_checkpoint`.
        NOTE(review): the return annotation says `DiffusionPipeline`, but the value is presumably a ControlNet
        model rather than a pipeline -- confirm against `convert_controlnet_checkpoint`.

    Raises:
        ValueError: If `omegaconf` (or `safetensors`, when requested) is not installed, or if the config has no
            `control_stage_config` entry.
    """
    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            checkpoint = {key: f.get_tensor(key) for key in f.keys()}
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # SECURITY: `torch.load` unpickles the file, which can execute arbitrary code. Only use this path with
        # checkpoints from a trusted source; prefer `from_safetensors=True` otherwise.
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    original_config = OmegaConf.load(original_config_file)

    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if "control_stage_config" not in original_config.model.params:
        raise ValueError("`control_stage_config` not present in original config")

    controlnet_model = convert_controlnet_checkpoint(
        checkpoint,
        original_config,
        checkpoint_path,
        image_size,
        upcast_attention,
        extract_ema,
        use_linear_projection=use_linear_projection,
        cross_attention_dim=cross_attention_dim,
    )

    return controlnet_model
def preprocess(image):
    """
    Deprecated helper that turns PIL image(s) or a list of tensors into a single batched tensor.

    A `FutureWarning` is always emitted. A `torch.Tensor` is returned untouched. PIL images are resized so that
    width and height are multiples of 8, stacked, scaled to `[-1, 1]` and moved to channels-first layout. A list
    of tensors is concatenated along dim 0. Any other input is returned as received.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    # Tensors pass straight through.
    if isinstance(image, torch.Tensor):
        return image

    # A lone PIL image is handled as a batch of one.
    images = [image] if isinstance(image, PIL.Image.Image) else image
    first = images[0]

    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # round both sides down to an integer multiple of 8
        width, height = width - width % 8, height - height % 8

        frames = []
        for pil_image in images:
            resized = pil_image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
            frames.append(np.array(resized)[None, :])

        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(first, torch.Tensor):
        return torch.cat(images, dim=0)

    return images
def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
    """
    Solve the DDIM update (formula (12) of https://arxiv.org/pdf/2010.02502.pdf) for its noise term.

    Given the sample `latents` at `timestep`, the model's `noise_pred` and the already known previous sample
    `prev_latents`, return the value `z` such that
    `prev_latents = sqrt(alpha_prev) * x0_pred + direction + sigma_t * z`, with `sigma_t = eta * sqrt(variance)`.
    The result is divided by `sigma_t`, so `eta` (and the scheduler variance) must be non-zero for a finite
    result.
    """
    # Previous timestep on the inference grid.
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - stride

    # Cumulative alpha products for the current and previous step.
    alpha_bar_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_bar_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_bar_prev = scheduler.final_alpha_cumprod
    beta_bar_t = 1 - alpha_bar_t

    # "predicted x_0", optionally clipped to [-1, 1].
    x0_pred = (latents - beta_bar_t ** (0.5) * noise_pred) / alpha_bar_t ** (0.5)
    if scheduler.config.clip_sample:
        x0_pred = torch.clamp(x0_pred, -1, 1)

    # sigma_t(eta), see formula (16).
    variance = scheduler._get_variance(timestep, prev_timestep)
    sigma_t = eta * variance ** (0.5)

    # "direction pointing to x_t", then the deterministic part of the update.
    direction = (1 - alpha_bar_prev - sigma_t**2) ** (0.5) * noise_pred
    deterministic_part = alpha_bar_prev ** (0.5) * x0_pred + direction

    # Whatever is left over is the scaled noise.
    return (prev_latents - deterministic_part) / sigma_t
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: DDIMScheduler,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Validate the components, patch outdated configs in place and register everything on the pipeline.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # Outdated scheduler configs: warn and force `steps_offset` to 1 on the scheduler's internal config.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Old UNet configs (saved before 0.9.0) with `sample_size` < 64: warn and force `sample_size` to 64.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Derived from the number of VAE blocks: 2 ** (len(block_out_channels) - 1).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """
    Validate the user-facing call arguments and raise `ValueError` on the first violation found.

    Checks, in order: `strength` lies in [0, 1]; `callback_steps` is a positive `int`; exactly one of `prompt`
    and `prompt_embeds` is given and `prompt` is a `str` or `list`; `negative_prompt` and
    `negative_prompt_embeds` are not both given; directly passed embeddings share the same shape.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # `None` is not an int, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None

    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """
    Build the noised starting latents and keep the clean (un-noised) latents alongside them.

    Args:
        image: Tensor moved to `device`/`dtype` first. When its second dimension is 4 it is used directly as
            latents; otherwise it is encoded with `self.vae` and multiplied by `self.vae.config.scaling_factor`.
        timestep: Passed to `self.scheduler.add_noise` together with freshly drawn noise.
        batch_size: Overwritten with `image.shape[0]` right after the device move (see the note below).
        num_images_per_prompt: Number of times the latents are repeated along dim 0.
        dtype: Dtype for `image` and for the sampled noise.
        device: Device for `image` and for the sampled noise.
        generator: Optional generator or list of generators. A list must have one entry per image; each image
            is then encoded and sampled separately with its own generator.

    Returns:
        Tuple `(latents, clean_latents)`: the output of `scheduler.add_noise` and the latents it was applied to.

    Raises:
        ValueError: If a generator list's length differs from the image batch size.
    """
    image = image.to(device=device, dtype=dtype)

    # NOTE(review): this rebinds the `batch_size` argument, so the caller's value is never used. Since
    # `init_latents` is derived from `image`, both `batch_size > init_latents.shape[0]` branches below look
    # unreachable -- confirm whether the rebinding is intended.
    batch_size = image.shape[0]

    if image.shape[1] == 4:
        # 4 channels: the input is taken as latents and is not encoded again
        init_latents = image

    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            # one encode + sample per image, each with its own generator
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        # repeat the whole latent batch once per requested image
        init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)

    # add noise to latents using the timestep
    shape = init_latents.shape
    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # get latents
    # keep a handle on the un-noised latents before `init_latents` is rebound to the noised version
    clean_latents = init_latents
    init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
    latents = init_latents

    return latents, clean_latents
The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. source_guidance_scale (`float`, *optional*, defaults to 1): Guidance scale for the source prompt. This is useful to control the amount of influence the source prompt for encoding. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.1): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs(prompt, strength, callback_steps) # 2. 
Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, prompt_embeds=prompt_embeds, lora_scale=text_encoder_lora_scale, ) source_prompt_embeds = self._encode_prompt( source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None ) # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents, clean_latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) source_latents = latents # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) generator = extra_step_kwargs.pop("generator", None) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) source_latent_model_input = torch.cat([source_latents] * 2) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) # predict the noise residual concat_latent_model_input = torch.stack( [ source_latent_model_input[0], latent_model_input[0], source_latent_model_input[1], latent_model_input[1], ], dim=0, ) concat_prompt_embeds = torch.stack( [ source_prompt_embeds[0], prompt_embeds[0], source_prompt_embeds[1], prompt_embeds[1], ], dim=0, ) concat_noise_pred = self.unet( concat_latent_model_input, t, cross_attention_kwargs=cross_attention_kwargs, encoder_hidden_states=concat_prompt_embeds, ).sample # perform guidance ( source_noise_pred_uncond, noise_pred_uncond, source_noise_pred_text, noise_pred_text, ) = concat_noise_pred.chunk(4, dim=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( source_noise_pred_text - source_noise_pred_uncond ) # Sample source_latents from the posterior distribution. prev_source_latents = posterior_sample( self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs ) # Compute noise. 
noise = compute_noise( self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs ) source_latents = prev_source_latents # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs ).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from packaging import version from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import deprecate, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> from diffusers import FlaxStableDiffusionPipeline >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16 ... 
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        """Register the sub-modules and patch legacy UNet configs whose `sample_size` is below 64."""
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Checkpoints saved before 0.9.0 may carry a too-small default `sample_size`; override it to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Ratio between pixel resolution and latent resolution (one factor of 2 per VAE down block after the first).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        """
        Tokenize `prompt` into token ids.

        The prompt is padded and truncated to `self.tokenizer.model_max_length` and the ids are requested as
        NumPy tensors (`return_tensors="np"`).

        Raises:
            ValueError: if `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on pre-extracted `features` and return its per-image flags."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Flag unsafe images and replace them with black images.

        Args:
            images: batch of `uint8` images indexable along the first axis.
            safety_model_params: parameters of the safety checker; they should already be replicated when `jit`
                is `True`.
            jit: whether to use the `pmap`-ed safety checker on sharded features.

        Returns:
            A tuple `(images, has_nsfw_concepts)`. `images` is the input object when nothing was flagged, and a
            copy with the flagged entries zeroed otherwise.
        """
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            # NOTE(review): the unreplicated params are not used after this point.
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        # Copy lazily so that the caller's array is never modified and clean batches are returned as-is.
        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
    ):
        """
        Run the full denoising loop for one (un-sharded) batch and decode the result.

        Classifier-free guidance is always applied: the unconditional embeddings come from `neg_prompt_ids`
        when given, otherwise from empty prompts.

        Returns:
            The decoded images, rescaled to `[0, 1]` and transposed to channels-last.

        Raises:
            ValueError: if `height` or `width` is not divisible by 8, or if `latents` has an unexpected shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        # Unconditional embeddings first, text embeddings second; `loop_body` splits the prediction in this order.
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # Ensure model output will be `float32` before going into the scheduler
        guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.array] = 7.5,
        latents: jnp.array = None,
        neg_prompt_ids: jnp.array = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                Token ids of the prompt or prompts to guide the image generation, e.g. as returned by
                `prepare_inputs`. They may be sharded (when `jit` is `True`) or not.
            params (`Dict` or `FrozenDict`):
                Model parameters. The entries `"text_encoder"`, `"unet"`, `"vae"`, `"scheduler"` and (when a safety
                checker is set) `"safety_checker"` are read.
            prng_seed (`jax.random.KeyArray`):
                Random key used to sample the initial latents when `latents` is not provided.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied `prng_seed`.
            neg_prompt_ids (`jnp.array`, *optional*):
                Token ids used for the unconditional branch of classifier-free guidance. When not provided, empty
                prompts are used.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]

            # The safety checker works on a flat list of images. Remember the generated layout
            # (with or without a leading device axis) so it can be restored afterwards.
            images_shape = images.shape
            flat_shape = (-1,) + tuple(images_shape[-3:])

            images_uint8_casted = (images * 255).round().astype("uint8")
            images_uint8_casted = np.asarray(images_uint8_casted).reshape(flat_shape)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # `np.array` makes a writable copy; a NumPy view of a JAX array cannot be assigned to.
            images = np.array(images).reshape(flat_shape)

            # block images: `has_nsfw_concept` is indexed per flat image, so assign on the flat layout
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(images_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
    static_broadcasted_argnums=(0, 4, 5, 6),
)
def _p_generate(
    pipe,
    prompt_ids,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
):
    """`pmap`-ed wrapper that forwards to `pipe._generate` on every device."""
    return pipe._generate(
        prompt_ids,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """`pmap`-ed wrapper that forwards to `pipe._get_has_nsfw_concepts` on every device."""
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x` (devices, per-device batch) into a single batch axis."""
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)
# It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works from ...utils import deprecate from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline # noqa: F401 deprecate( "stable diffusion controlnet", "0.22.0", "Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.flax_pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.", standard_warn=False, stacklevel=3, ) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import jax
        >>> import numpy as np
        >>> import jax.numpy as jnp
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> import requests
        >>> from io import BytesIO
        >>> from PIL import Image
        >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline


        >>> def create_key(seed=0):
        ...     return jax.random.PRNGKey(seed)


        >>> rng = create_key(0)

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        >>> response = requests.get(url)
        >>> init_img = Image.open(BytesIO(response.content)).convert("RGB")
        >>> init_img = init_img.resize((768, 512))

        >>> prompts = "A fantasy landscape, trending on artstation"

        >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
        ...     "CompVis/stable-diffusion-v1-4",
        ...     revision="flax",
        ...     dtype=jnp.bfloat16,
        ... )

        >>> num_samples = jax.device_count()
        >>> rng = jax.random.split(rng, jax.device_count())
        >>> prompt_ids, processed_image = pipeline.prepare_inputs(
        ...     prompt=[prompts] * num_samples, image=[init_img] * num_samples
        ... )
        >>> p_params = replicate(params)
        >>> prompt_ids = shard(prompt_ids)
        >>> processed_image = shard(processed_image)

        >>> output = pipeline(
        ...     prompt_ids=prompt_ids,
        ...     image=processed_image,
        ...     params=p_params,
        ...     prng_seed=rng,
        ...     strength=0.75,
        ...     num_inference_steps=50,
        ...     jit=True,
        ...     height=512,
        ...     width=768,
        ... ).images

        >>> output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
        ```
"""


class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for image-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            # `Logger.warn` is a deprecated alias of `Logger.warning`; use the supported spelling.
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Ratio between pixel resolution and latent resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
        """
        Tokenize `prompt` and turn the PIL `image`(s) into one batched float32 array.

        Returns:
            A tuple `(prompt_ids, processed_images)`.

        Raises:
            ValueError: if `prompt` is not a `str`/`list` or `image` is not a `PIL.Image.Image`/`list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids, processed_images

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on already extracted image `features`."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Flag NSFW images and replace each flagged one with a black image.

        `images` is iterated image by image and converted with `Image.fromarray`. The input array is copied before
        the first replacement, so the caller's array is never modified.

        Returns:
            A tuple `(images, has_nsfw_concepts)`.
        """
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def get_timestep_start(self, num_inference_steps, strength):
        """Return the index of the first scheduler step to run for the given `strength`."""
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)

        return t_start

    def _generate(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        start_timestep: int,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        noise: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
    ):
        """
        Encode `image`, noise it at `start_timestep`, denoise with classifier-free guidance and decode the result.

        Returns:
            The decoded images, clipped to `[0, 1]` and transposed to channels-last.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8, or if `noise` does not have the latents shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if noise is None:
            noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if noise.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}")

        # Create init_latents
        init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
        init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
        init_latents = self.vae.config.scaling_factor * init_latents

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
        )

        latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size)

        latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(start_timestep, num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        strength: float = 0.8,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.array] = 7.5,
        noise: jnp.array = None,
        neg_prompt_ids: jnp.array = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                The tokenized prompt or prompts to guide the image generation.
            image (`jnp.array`):
                Array representing an image batch, that will be used as the starting point for the process.
            params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
            prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. A plain Python number (`int` or `float`) is expanded to
                one value per prompt.
            noise (`jnp.array`, *optional*):
                Pre-generated noise, sampled from a Gaussian distribution, that is added to the encoded `image`. Can
                be used to tweak the same generation with different prompts. If not provided, the noise is sampled
                using `prng_seed`.
            neg_prompt_ids (`jnp.array`, *optional*):
                Tokenized negative prompt or prompts. If not provided, empty prompts are tokenized and used as the
                unconditional input.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Accept `int` as well as `float`: a bare Python scalar of either type has no axis to map over in `pmap`.
        if isinstance(guidance_scale, (int, float)):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([float(guidance_scale)] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        start_timestep = self.get_timestep_start(num_inference_steps, strength)

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                image,
                params,
                prng_seed,
                start_timestep,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                noise,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                image,
                params,
                prng_seed,
                start_timestep,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                noise,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            output_shape = images.shape
            image_shape = output_shape[-3:]

            # Collapse every leading axis (devices and/or batch) so that image `i` lines up with flag `i`,
            # whether or not the output is sharded.
            images_uint8_casted = np.asarray((images * 255).round().astype("uint8")).reshape(-1, *image_shape)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # `np.array` makes a writable host copy; `np.asarray` on a device array may yield a read-only view,
            # which would make the in-place replacement below fail.
            images = np.array(images).reshape(-1, *image_shape)

            # block images
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = images_uint8_casted[i]

            # Restore the layout `_generate` / `_p_generate` produced.
            images = images.reshape(output_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial( jax.pmap, in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), static_broadcasted_argnums=(0, 5, 6, 7, 8), ) def _p_generate( pipe, prompt_ids, image, params, prng_seed, start_timestep, num_inference_steps, height, width, guidance_scale, noise, neg_prompt_ids, ): return pipe._generate( prompt_ids, image, params, prng_seed, start_timestep, num_inference_steps, height, width, guidance_scale, noise, neg_prompt_ids, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) def _p_get_has_nsfw_concepts(pipe, features, params): return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): # einops.rearrange(x, 'd b ... -> (d b) ...') num_devices, batch_size = x.shape[:2] rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) def preprocess(image, dtype): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from packaging import version from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> import PIL >>> import requests >>> from io import BytesIO >>> from diffusers import FlaxStableDiffusionInpaintPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" >>> init_image = download_image(img_url).resize((512, 512)) >>> mask_image = download_image(mask_url).resize((512, 512)) >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained( ... "xvjiarui/stable-diffusion-2-inpainting" ... 
) >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" >>> prng_seed = jax.random.PRNGKey(0) >>> num_inference_steps = 50 >>> num_samples = jax.device_count() >>> prompt = num_samples * [prompt] >>> init_image = num_samples * [init_image] >>> mask_image = num_samples * [mask_image] >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs( ... prompt, init_image, mask_image ... ) # shard inputs and rng >>> params = replicate(params) >>> prng_seed = jax.random.split(prng_seed, jax.device_count()) >>> prompt_ids = shard(prompt_ids) >>> processed_masked_images = shard(processed_masked_images) >>> processed_masks = shard(processed_masks) >>> images = pipeline( ... prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True ... ).images >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) ``` """ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): r""" Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`FlaxAutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`FlaxCLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or [`FlaxDPMSolverMultistepScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ def __init__( self, vae: FlaxAutoencoderKL, text_encoder: FlaxCLIPTextModel, tokenizer: CLIPTokenizer, unet: FlaxUNet2DConditionModel, scheduler: Union[ FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler ], safety_checker: FlaxStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, dtype: jnp.dtype = jnp.float32, ): super().__init__() self.dtype = dtype if safety_checker is None: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_inputs( self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]], mask: Union[Image.Image, List[Image.Image]], ): if not isinstance(prompt, (str, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if not isinstance(image, (Image.Image, list)): raise ValueError(f"image has to be of type `PIL.Image.Image` or list 
but is {type(image)}") if isinstance(image, Image.Image): image = [image] if not isinstance(mask, (Image.Image, list)): raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") if isinstance(mask, Image.Image): mask = [mask] processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image]) processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask]) # processed_masks[processed_masks < 0.5] = 0 processed_masks = processed_masks.at[processed_masks < 0.5].set(0) # processed_masks[processed_masks >= 0.5] = 1 processed_masks = processed_masks.at[processed_masks >= 0.5].set(1) processed_masked_images = processed_images * (processed_masks < 0.5) text_input = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) return text_input.input_ids, processed_masked_images, processed_masks def _get_has_nsfw_concepts(self, features, params): has_nsfw_concepts = self.safety_checker(features, params) return has_nsfw_concepts def _run_safety_checker(self, images, safety_model_params, jit=False): # safety_model_params should already be replicated when jit is True pil_images = [Image.fromarray(image) for image in images] features = self.feature_extractor(pil_images, return_tensors="np").pixel_values if jit: features = shard(features) has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) has_nsfw_concepts = unshard(has_nsfw_concepts) safety_model_params = unreplicate(safety_model_params) else: has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) images_was_copied = False for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): if has_nsfw_concept: if not images_was_copied: images_was_copied = True images = images.copy() images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image if any(has_nsfw_concepts): warnings.warn( "Potential NSFW content was detected in one or 
more images. A black image will be returned" " instead. Try again with a different prompt and/or seed." ) return images, has_nsfw_concepts def _generate( self, prompt_ids: jnp.array, mask: jnp.array, masked_image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, num_inference_steps: int, height: int, width: int, guidance_scale: float, latents: Optional[jnp.array] = None, neg_prompt_ids: Optional[jnp.array] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") # get prompt text embeddings prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] if neg_prompt_ids is None: uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" ).input_ids else: uncond_input = neg_prompt_ids negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) latents_shape = ( batch_size, self.vae.config.latent_channels, height // self.vae_scale_factor, width // self.vae_scale_factor, ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") prng_seed, mask_prng_seed = jax.random.split(prng_seed) masked_image_latent_dist = self.vae.apply( {"params": params["vae"]}, masked_image, method=self.vae.encode ).latent_dist masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) masked_image_latents = self.vae.config.scaling_factor * masked_image_latents del 
mask_prng_seed mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest") # 8. Check that sizes of mask, masked image and latents match num_channels_latents = self.vae.config.latent_channels num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) def loop_body(step, args): latents, mask, masked_image_latents, scheduler_state = args # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes latents_input = jnp.concatenate([latents] * 2) mask_input = jnp.concatenate([mask] * 2) masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] timestep = jnp.broadcast_to(t, latents_input.shape[0]) latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) # concat latents, mask, masked_image_latents in the channel dimension latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1) # predict the noise residual noise_pred = self.unet.apply( {"params": params["unet"]}, jnp.array(latents_input), jnp.array(timestep, dtype=jnp.int32), encoder_hidden_states=context, ).sample # perform guidance noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents, mask, masked_image_latents, scheduler_state scheduler_state = self.scheduler.set_timesteps( params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape ) # scale the initial noise by the standard deviation required by the scheduler latents = latents * params["scheduler"].init_noise_sigma if DEBUG: # run with python for loop for i in range(num_inference_steps): latents, mask, masked_image_latents, scheduler_state = loop_body( i, (latents, mask, masked_image_latents, scheduler_state) ) else: latents, _, _, _ = jax.lax.fori_loop( 0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state) ) # scale and decode the image latents with vae latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, 
method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) return image @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt_ids: jnp.array, mask: jnp.array, masked_image: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.KeyArray, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, neg_prompt_ids: jnp.array = None, return_dict: bool = True, jit: bool = False, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. latents (`jnp.array`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. tensor will ge generated by sampling using the supplied random `generator`. 
jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic") mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest") if isinstance(guidance_scale, float): # Convert to a tensor so each device gets a copy. Follow the prompt_ids for # shape information, as they may be sharded (when `jit` is `True`), or not. 
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) if len(prompt_ids.shape) > 2: # Assume sharded guidance_scale = guidance_scale[:, None] if jit: images = _p_generate( self, prompt_ids, mask, masked_image, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids, ) else: images = self._generate( prompt_ids, mask, masked_image, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids, ) if self.safety_checker is not None: safety_params = params["safety_checker"] images_uint8_casted = (images * 255).round().astype("uint8") num_devices, batch_size = images.shape[:2] images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) images = np.asarray(images) # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): if is_nsfw: images[i] = np.asarray(images_uint8_casted[i]) images = images.reshape(num_devices, batch_size, height, width, 3) else: images = np.asarray(images) has_nsfw_concept = False if not return_dict: return (images, has_nsfw_concept) return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) # Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). 
@partial( jax.pmap, in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), static_broadcasted_argnums=(0, 6, 7, 8), ) def _p_generate( pipe, prompt_ids, mask, masked_image, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids, ): return pipe._generate( prompt_ids, mask, masked_image, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) def _p_get_has_nsfw_concepts(pipe, features, params): return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): # einops.rearrange(x, 'd b ... -> (d b) ...') num_devices, batch_size = x.shape[:2] rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) def preprocess_image(image, dtype): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 def preprocess_mask(mask, dtype): w, h = mask.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w, h)) mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 mask = jnp.expand_dims(mask, axis=(0, 1)) return mask ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, List, Optional, Union import numpy as np import torch from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import deprecate, logging from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) class OnnxStableDiffusionPipeline(DiffusionPipeline): vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def check_inputs( self, prompt: Union[str, List[str]], height: Optional[int], width: Optional[int], callback_steps: int, negative_prompt: Optional[str] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = 512, width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): `Image`, or tensor representing an image batch which will be upscaled. * num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): One or a list of [numpy generator(s)](TODO) to make generation deterministic. latents (`np.ndarray`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # get the initial random noise unless the user supplied it latents_dtype = prompt_embeds.dtype latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) elif latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # set timesteps self.scheduler.set_timesteps(num_inference_steps) latents = latents * np.float64(self.scheduler.init_noise_sigma) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) latent_model_input = latent_model_input.cpu().numpy() # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) noise_pred = noise_pred[0] # perform guidance if 
do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs ) latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 image = np.concatenate( [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] ) image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) if self.safety_checker is not None: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] ) images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) else: has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, ): deprecation_message = "Please use `OnnxStableDiffusionPipeline` 
instead of `StableDiffusionOnnxPipeline`." deprecate("StableDiffusionOnnxPipeline", "1.0.0", deprecation_message) super().__init__( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Callable, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 def preprocess(image): warnings.warn( ( "The preprocess method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor.preprocess instead" ), FutureWarning, ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): r""" Pipeline for text-guided image to image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def check_inputs( self, prompt: Union[str, List[str]], callback_steps: int, negative_prompt: Optional[Union[str, List[str]]] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. 
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[np.ndarray, PIL.Image.Image] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[np.random.RandomState] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation. Pass `None` when supplying `prompt_embeds`.
        image (`np.ndarray` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter will be modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            Guidance is enabled by setting `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only forwarded
            to schedulers whose `step` method accepts an `eta` argument.
        generator (`np.random.RandomState`, *optional*):
            A np.random.RandomState to make generation deterministic. Defaults to the global `np.random` module.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from
            `negative_prompt`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` converts the result to `PIL.Image.Image`; any other
            value returns the `np.ndarray` unchanged.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the generated images and the second element is a
        list of flags returned by the `safety_checker` (or `None` when no safety checker is configured).

    Raises:
        ValueError: if `strength` is outside `[0, 1]`, if the inputs fail `check_inputs`, or if the image batch
            cannot be evenly duplicated to match the number of prompts.
    """
    # check inputs. Raise error if not correct
    self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if generator is None:
        generator = np.random

    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps)

    image = preprocess(image).cpu().numpy()

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    latents_dtype = prompt_embeds.dtype
    image = image.astype(latents_dtype)
    # encode the init image into latents and scale the latents
    init_latents = self.vae_encoder(sample=image)[0]
    init_latents = 0.18215 * init_latents

    # `batch_size` is the number of prompts regardless of whether they arrived as text or as
    # pre-computed embeddings; `prompt` itself may be None here, so it must not be measured with len().
    num_prompts = batch_size
    num_init_images = init_latents.shape[0]
    if num_prompts > num_init_images and num_prompts % num_init_images == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {num_prompts} text prompts (`prompt`), but only {num_init_images} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = num_prompts // num_init_images
        init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0)
    elif num_prompts > num_init_images and num_prompts % num_init_images != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {num_init_images} to {num_prompts} text prompts."
        )
    else:
        init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)

    # get the original timestep using init_timestep
    offset = self.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
    timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

    # add noise to latents using the timesteps
    noise = generator.randn(*init_latents.shape).astype(latents_dtype)
    init_latents = self.scheduler.add_noise(
        torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
    )
    init_latents = init_latents.numpy()

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    latents = init_latents

    # skip the part of the schedule that corresponds to noise levels above the one just added
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    timesteps = self.scheduler.timesteps[t_start:].numpy()

    # look up the dtype the ONNX UNet declares for its `timestep` input (float when not declared)
    timestep_dtype = next(
        (model_input.type for model_input in self.unet.model.get_inputs() if model_input.name == "timestep"),
        "tensor(float)",
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    for i, t in enumerate(self.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
        latent_model_input = latent_model_input.cpu().numpy()

        # predict the noise residual
        timestep = np.array([t], dtype=timestep_dtype)
        noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
            0
        ]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        scheduler_output = self.scheduler.step(
            torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
        )
        latents = scheduler_output.prev_sample.numpy()

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    latents = 1 / 0.18215 * latents
    # image = self.vae_decoder(latent_sample=latents)[0]
    # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
    image = np.concatenate(
        [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
    )

    image = np.clip(image / 2 + 0.5, 0, 1)
    image = image.transpose((0, 2, 3, 1))

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)
        # safety_checker does not support batched inputs yet
        images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def prepare_mask_and_masked_image(image, mask, latents_shape):
    """
    Build the latent-resolution mask and the pixel-resolution masked image used for inpainting.

    Args:
        image: object exposing PIL-style `convert` and `resize`; it is converted to RGB.
        mask: object exposing PIL-style `convert` and `resize`; it is converted to single-channel "L".
        latents_shape: indexable whose items 0 and 1 are the latent height and width.

    Returns:
        A `(mask, masked_image)` tuple. `mask` is a float32 array with two leading singleton axes holding only
        0 and 1 (threshold 0.5 after dividing by 255). `masked_image` is the image scaled to `[-1, 1]` with a
        leading batch axis, channels moved in front of the spatial axes, and every pixel zeroed where the
        resized mask is 127.5 or brighter.
    """
    latent_h = latents_shape[0]
    latent_w = latents_shape[1]
    # `resize` takes (width, height); pixel space is 8x the latent resolution
    pixel_size = (latent_w * 8, latent_h * 8)

    pixels = np.array(image.convert("RGB").resize(pixel_size))
    pixels = pixels[None].transpose(0, 3, 1, 2)
    pixels = pixels.astype(np.float32) / 127.5 - 1.0

    # True where the pixel is kept (dark mask), False where it will be repainted
    keep = np.array(mask.convert("L").resize(pixel_size)) < 127.5
    masked_image = pixels * keep

    latent_mask = mask.resize((latent_w, latent_h), PIL_INTERPOLATION["nearest"])
    latent_mask = np.array(latent_mask.convert("L"))
    latent_mask = latent_mask.astype(np.float32) / 255.0
    latent_mask = latent_mask[None, None]
    # binarise in place
    latent_mask[latent_mask < 0.5] = 0
    latent_mask[latent_mask >= 0.5] = 1

    return latent_mask, masked_image
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components, patching outdated scheduler configs on the way.

    Args:
        vae_encoder: ONNX model used to encode images into latents.
        vae_decoder: ONNX model used to decode latents into images.
        text_encoder: ONNX text encoder.
        tokenizer: tokenizer matching `text_encoder`.
        unet: ONNX denoising model.
        scheduler: scheduler used together with `unet`. Its config is rewritten in place when `steps_offset`
            is not 1 or `clip_sample` is `True` (a deprecation notice is emitted in both cases).
        safety_checker: optional safety checker model; `None` disables it.
        feature_extractor: feature extractor feeding `safety_checker`; required when the checker is set.
        requires_safety_checker: when `True`, a warning is logged if `safety_checker` is `None`.

    Raises:
        ValueError: if `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()
    logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # replace the scheduler's frozen config with a corrected copy
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so that the pipeline class is actually interpolated into the message
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int],
    width: Optional[int],
    callback_steps: int,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Raises:
        ValueError: when `height`/`width` are not multiples of 8, when `callback_steps` is not a positive
            integer, when `prompt` and `prompt_embeds` are both given or both missing, when `prompt` is neither
            `str` nor `list`, when `negative_prompt` and `negative_prompt_embeds` are both given, or when the
            two embedding arrays differ in shape.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    # exactly one of `prompt` / `prompt_embeds` must be supplied
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. latents (`np.ndarray`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # check inputs. 
Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random # set timesteps self.scheduler.set_timesteps(num_inference_steps) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) num_channels_latents = NUM_LATENT_CHANNELS latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8) latents_dtype = prompt_embeds.dtype if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # prepare mask and masked_image mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:]) mask = mask.astype(latents.dtype) masked_image = masked_image.astype(latents.dtype) masked_image_latents = self.vae_encoder(sample=masked_image)[0] masked_image_latents = 0.18215 * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt mask = mask.repeat(batch_size * num_images_per_prompt, 0) masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask masked_image_latents = ( 
np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents ) num_channels_mask = mask.shape[1] num_channels_masked_image = masked_image_latents.shape[1] unet_input_channels = NUM_UNET_INPUT_CHANNELS if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels: raise ValueError( "Incorrect configuration settings! The config of `pipeline.unet` expects" f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.unet` or your `mask_image` or `image` input." ) # set timesteps self.scheduler.set_timesteps(num_inference_steps) # scale the initial noise by the standard deviation required by the scheduler latents = latents * np.float64(self.scheduler.init_noise_sigma) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latnets in the channel dimension latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) latent_model_input = latent_model_input.cpu().numpy() latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ 0 ] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs ) latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 image = np.concatenate( [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] ) image = np.clip(image 
/ 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) if self.safety_checker is not None: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) # safety_checker does not support batched inputs yet images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] ) images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) else: has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py ================================================ import inspect from typing import Callable, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import deprecate, logging from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name def preprocess(image): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? mask = 1 - mask # repaint white, keep black return mask class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): r""" Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel feature_extractor: CLIPImageProcessor def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def check_inputs( self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def __call__( self, prompt: Union[str, List[str]], image: Union[np.ndarray, PIL.Image.Image] = None, mask_image: Union[np.ndarray, PIL.Image.Image] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`nd.ndarray` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. This is the image whose masked region will be inpainted. mask_image (`nd.ndarray` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. 
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # check inputs. 
Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if generator is None: generator = np.random # set timesteps self.scheduler.set_timesteps(num_inference_steps) if isinstance(image, PIL.Image.Image): image = preprocess(image) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) # encode the init image into latents and scale the latents init_latents = self.vae_encoder(sample=image)[0] init_latents = 0.18215 * init_latents # Expand init_latents for batch_size and num_images_per_prompt init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) init_latents_orig = init_latents # preprocess mask if not isinstance(mask_image, np.ndarray): mask_image = preprocess_mask(mask_image, 8) mask_image = mask_image.astype(latents_dtype) mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) # check sizes if not mask.shape == init_latents.shape: raise ValueError("The mask and image should be the same size!") # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = 
min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps.numpy()[-init_timestep] timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) # add noise to latents using the timesteps noise = generator.randn(*init_latents.shape).astype(latents_dtype) init_latents = self.scheduler.add_noise( torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) ) init_latents = init_latents.numpy() # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:].numpy() timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ 0 ] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( torch.from_numpy(noise_pred), t, 
torch.from_numpy(latents), **extra_step_kwargs ).prev_sample latents = latents.numpy() init_latents_proper = self.scheduler.add_noise( torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) ) init_latents_proper = init_latents_proper.numpy() latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 image = np.concatenate( [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] ) image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) if self.safety_checker is not None: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) # There will throw an error if use safety_checker batchsize>1 images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] ) images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) else: has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py ================================================ from logging import getLogger from typing import Any, Callable, List, Optional, Union import numpy as np import PIL import torch from ...schedulers import DDPMScheduler from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from 
..pipeline_utils import ImagePipelineOutput from . import StableDiffusionUpscalePipeline logger = getLogger(__name__) NUM_LATENT_CHANNELS = 4 NUM_UNET_INPUT_CHANNELS = 7 ORT_TO_PT_TYPE = { "float16": torch.float16, "float32": torch.float32, } def preprocess(image): if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32 image = [np.array(i.resize((w, h)))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): def __init__( self, vae: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: Any, unet: OnnxRuntimeModel, low_res_scheduler: DDPMScheduler, scheduler: Any, max_noise_level: int = 350, ): super().__init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, low_res_scheduler=low_res_scheduler, scheduler=scheduler, safety_checker=None, feature_extractor=None, watermarker=None, max_noise_level=max_noise_level, ) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], num_inference_steps: int = 75, guidance_scale: float = 9.0, noise_level: int = 20, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, 
torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`np.ndarray` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. noise_level TODO negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs(prompt, image, noise_level, callback_steps) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_embeddings = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] # 4. Preprocess image image = preprocess(image) image = image.cpu() # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) noise_level = np.concatenate([noise_level] * image.shape[0]) # 6. Prepare latent variables height, width = image.shape[2:] latents = self.prepare_latents( batch_size * num_images_per_prompt, NUM_LATENT_CHANNELS, height, width, latents_dtype, device, generator, latents, ) # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS: raise ValueError( "Incorrect configuration settings! 
The config of `pipeline.unet` expects" f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +" f" `num_channels_image`: {num_channels_image} " f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = np.concatenate([latent_model_input, image], axis=1) # timestep to tensor timestep = np.array([t], dtype=timestep_dtype) # predict the noise residual noise_pred = self.unet( sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings, class_labels=noise_level.astype(np.int64), )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs ).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None 
and i % callback_steps == 0: callback(i, t, latents) # 10. Post-processing image = self.decode_latents(latents.float()) # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) def decode_latents(self, latents): latents = 1 / 0.08333 * latents image = self.vae(latent_sample=latents)[0] image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) return image def _encode_prompt( self, prompt: Union[str, List[str]], device, num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) # no positional arguments to text_encoder prompt_embeds = self.text_encoder( input_ids=text_input_ids.int().to(device), # attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt) prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) # if hasattr(uncond_input, "attention_mask"): # attention_mask = uncond_input.attention_mask.to(device) # else: # attention_mask = None uncond_embeddings = self.text_encoder( input_ids=uncond_input.input_ids.int().to(device), # attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] if do_classifier_free_guidance: seq_len = uncond_embeddings.shape[1] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds]) return prompt_embeds ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionPipeline >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> image = pipe(prompt).images[0] ``` """ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. 
This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Passing `"latent"`
            skips VAE decoding and the safety checker and hands the final latents to the image processor.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also forwarded to `_encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR. Only applied when
            classifier-free guidance is active and the value is greater than `0.0`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`. The second element is `None` when `output_type` is
        `"latent"`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # no prompt was given, so the batch size comes from the pre-computed embeddings
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    # timesteps beyond `num_inference_steps * order` do not advance the progress bar
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                # the batch was doubled above: first half unconditional, second half text-conditioned
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 8. Decode the latents (unless raw latents were requested) and run the safety checker
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are excluded from denormalization
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# See the License for the specific language governing permissions and # limitations under the License. import inspect import math import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionAttendAndExcitePipeline >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "a cat and a frog" >>> # use get_indices function to find out indices of the tokens you want to alter >>> pipe.get_indices(prompt) {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} >>> token_indices = [2, 5] >>> seed = 6141 >>> generator = torch.Generator("cuda").manual_seed(seed) >>> images = pipe( ... prompt=prompt, ... token_indices=token_indices, ... guidance_scale=7.5, ... generator=generator, ... num_inference_steps=50, ... max_iter_to_alter=25, ... 
class AttentionStore:
    """Accumulates attention maps per UNet location ("down", "mid", "up") for one step.

    Maps are appended to ``step_store`` as they arrive through ``__call__``. Once
    ``num_att_layers`` calls have been counted, the collected maps are published as
    ``attention_store`` and a fresh ``step_store`` is started.

    ``num_att_layers`` starts at -1 and is never set inside this class, so the roll-over
    only happens after the owner assigns the real layer count.
    """

    def __init__(self, attn_res):
        """Create an empty store.

        :param attn_res: two-element (height, width) resolution; only maps whose second
            dimension equals ``height * width`` are kept.
        """
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.curr_step_index = 0
        self.attn_res = attn_res

    @staticmethod
    def get_empty_store():
        """Return a new mapping with an empty list for each UNet location."""
        return {place: [] for place in ("down", "mid", "up")}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` if it is a cross-attention map at the tracked resolution.

        Every call advances the layer counter, whether or not the map is kept.
        """
        keep = self.cur_att_layer >= 0 and is_cross and attn.shape[1] == np.prod(self.attn_res)
        if keep:
            self.step_store[place_in_unet].append(attn)

        self.cur_att_layer += 1
        if self.cur_att_layer != self.num_att_layers:
            return
        # all layers have reported: publish this step's maps and start over
        self.cur_att_layer = 0
        self.between_steps()

    def between_steps(self):
        """Publish the maps gathered so far and begin a new, empty step store."""
        self.attention_store, self.step_store = self.step_store, self.get_empty_store()

    def get_average_attention(self):
        """Return the published ``attention_store`` unchanged (no averaging happens here)."""
        return self.attention_store

    def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
        """Aggregates the attention across the different layers and heads at the specified resolution."""
        published = self.get_average_attention()
        per_layer = [
            layer_map.reshape(-1, self.attn_res[0], self.attn_res[1], layer_map.shape[-1])
            for place in from_where
            for layer_map in published[place]
        ]
        stacked = torch.cat(per_layer, dim=0)
        # mean over the concatenated leading dimension
        return stacked.sum(0) / stacked.shape[0]

    def reset(self):
        """Clear the counter and both stores."""
        self.cur_att_layer = 0
        self.step_store = self.get_empty_store()
        self.attention_store = {}
= attn.to_q(hidden_states) is_cross = encoder_hidden_states is not None encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) # only need to store attention maps during the Attend and Excite process if attention_probs.requires_grad: self.attnstore(attention_probs, is_cross, self.place_in_unet) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the VAE scale factor.

    Args:
        vae: Variational auto-encoder used to encode and decode images to and from latents.
        text_encoder: Frozen CLIP text encoder.
        tokenizer: CLIP tokenizer paired with ``text_encoder``.
        unet: Conditional U-Net that denoises the encoded image latents.
        scheduler: Scheduler used together with ``unet`` to denoise the latents.
        safety_checker: Optional module that flags potentially harmful images; may be ``None``.
        feature_extractor: Produces the inputs of ``safety_checker``; required whenever a
            safety checker is given.
        requires_safety_checker: Stored in the pipeline config. When ``True`` and
            ``safety_checker`` is ``None`` a warning is logged.

    Raises:
        ValueError: If ``safety_checker`` is provided without a ``feature_extractor``.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # spatial ratio between pixel space and latent space: a factor of 2 per VAE block after the first
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
def _execution_device(self):
    r"""
    Return the device on which the pipeline's models will be executed.

    Adapted from `StableDiffusionPipeline._execution_device`. When the UNet carries no `_hf_hook`
    attribute, `self.device` is returned. Otherwise the UNet's modules are scanned in order and the first
    hook exposing a non-`None` `execution_device` decides the result; `self.device` is the fallback when
    no module provides one.
    """
    unet = self.unet
    if hasattr(unet, "_hf_hook"):
        for submodule in unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over ``image``.

    Args:
        image: Decoded images, either a torch tensor or another array type accepted by
            ``self.image_processor.numpy_to_pil``.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted ``pixel_values`` are cast to before checking.

    Returns:
        ``(image, has_nsfw_concept)``. When ``self.safety_checker`` is ``None`` the image is
        returned untouched together with ``None``; otherwise both values are whatever the
        safety checker returns.
    """
    if self.safety_checker is None:
        return image, None

    # the feature extractor is fed PIL images, so convert from whichever format we were handed
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def check_inputs(
    self,
    prompt,
    indices,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call before any work is done.

    Args:
        prompt: A ``str``, a ``list`` of prompts, or ``None`` when ``prompt_embeds`` is given.
        indices: Token indices to alter: a non-empty list of ints (single prompt) or a
            non-empty list of non-empty lists of ints (one list per prompt).
        height: Image height in pixels; must be divisible by 8.
        width: Image width in pixels; must be divisible by 8.
        callback_steps: Positive integer callback frequency.
        negative_prompt: Optional negative prompt(s); exclusive with ``negative_prompt_embeds``.
        prompt_embeds: Optional pre-computed prompt embeddings; exclusive with ``prompt``.
        negative_prompt_embeds: Optional pre-computed negative embeddings; must match the
            shape of ``prompt_embeds`` when both are given.

    Raises:
        ValueError: If sizes, ``callback_steps``, the prompt/embedding combination, or the
            batch size of ``indices`` versus the prompt are inconsistent.
        TypeError: If ``indices`` is not a list of ints or a list of lists of ints. Empty
            lists are rejected here as well instead of surfacing as an ``IndexError``.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an int, so this single test also covers a missing value
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    def _is_nonempty_int_list(value):
        # every element is checked; an empty list is rejected rather than indexed
        return isinstance(value, list) and len(value) > 0 and all(isinstance(item, int) for item in value)

    indices_is_list_ints = _is_nonempty_int_list(indices)
    indices_is_list_list_ints = (
        isinstance(indices, list) and len(indices) > 0 and all(_is_nonempty_int_list(row) for row in indices)
    )

    if not indices_is_list_ints and not indices_is_list_list_ints:
        raise TypeError("`indices` must be a list of ints or a list of a list of ints")

    # a flat list addresses a single prompt; a nested list holds one entry per prompt
    indices_batch_size = 1 if indices_is_list_ints else len(indices)

    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    else:
        # the checks above guarantee `prompt_embeds` is set whenever `prompt` is None
        prompt_batch_size = prompt_embeds.shape[0]

    if indices_batch_size != prompt_batch_size:
        raise ValueError(
            f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
        )
@staticmethod
def _compute_max_attention_per_index(
    attention_maps: torch.Tensor,
    indices: List[int],
) -> List[torch.Tensor]:
    """Computes the maximum attention value for each of the tokens we wish to alter.

    Args:
        attention_maps: Aggregated attention tensor. It is indexed as `[:, :, token]`, so the last
            axis is the token axis; the first and last entries of that axis are discarded.
        indices: Token positions, expressed relative to the un-sliced token axis.

    Returns:
        One scalar tensor per entry of `indices`: the maximum of that token's smoothed map.
    """
    # Scale out of place: `attention_maps[:, :, 1:-1]` is a view, so an in-place `*=` would
    # silently rescale the caller's tensor as well.
    attention_for_text = attention_maps[:, :, 1:-1] * 100
    attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)

    # The smoothing module does not depend on the token, so build and move it once.
    smoothing = GaussianSmoothing().to(attention_maps.device)

    max_indices_list = []
    for index in indices:
        # Shift by one since the first token was removed by the slice above.
        image = attention_for_text[:, :, index - 1]
        # Add batch/channel axes and pad by one pixel per side before smoothing.
        padded = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
        smoothed = smoothing(padded).squeeze(0).squeeze(0)
        max_indices_list.append(smoothed.max())
    return max_indices_list
torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0] latents = latents - step_size * grad_cond return latents def _perform_iterative_refinement_step( self, latents: torch.Tensor, indices: List[int], loss: torch.Tensor, threshold: float, text_embeddings: torch.Tensor, step_size: float, t: int, max_refinement_steps: int = 20, ): """ Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent code according to our loss objective until the given threshold is reached for all tokens. """ iteration = 0 target_loss = max(0, 1.0 - threshold) while loss > target_loss: iteration += 1 latents = latents.clone().detach().requires_grad_(True) self.unet(latents, t, encoder_hidden_states=text_embeddings).sample self.unet.zero_grad() # Get max activation value for each subject token max_attention_per_index = self._aggregate_and_get_max_attention_per_token( indices=indices, ) loss = self._compute_loss(max_attention_per_index) if loss != 0: latents = self._update_latent(latents, loss, step_size) logger.info(f"\t Try {iteration}. loss: {loss}") if iteration >= max_refinement_steps: logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ") break # Run one more time but don't compute gradients and update the latents. 
def get_indices(self, prompt: str) -> Dict[int, str]:
    """Utility function to list the indices of the tokens you wish to alter.

    Tokenizes `prompt` with `self.tokenizer` and maps every position of the resulting id sequence
    to the token string found at that position.

    Args:
        prompt: The text to tokenize.

    Returns:
        A dictionary `{position: token}` covering every id returned by the tokenizer.
    """
    ids = self.tokenizer(prompt).input_ids
    return dict(enumerate(self.tokenizer.convert_ids_to_tokens(ids)))
True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, max_iter_to_alter: int = 25, thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, scale_factor: int = 20, attn_res: Optional[Tuple[int]] = (16, 16), ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. token_indices (`List[int]`): The token indices to alter with attend-and-excite. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step.
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). max_iter_to_alter (`int`, *optional*, defaults to `25`): Number of denoising steps to apply attend-and-excite. The first denoising steps are where the attend-and-excite is applied. I.e. if `max_iter_to_alter` is 25 and there are a total of `30` denoising steps, the first 25 denoising steps will apply attend-and-excite and the last 5 will not apply attend-and-excite. thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. scale_factor (`int`, *optional*, default to 20): Scale factor that controls the step size of each Attend and Excite update. attn_res (`tuple`, *optional*, default computed from width and height): The 2D resolution of the semantic attention map. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. :type attention_store: object """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, token_indices, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) if attn_res is None: attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) self.attention_store = AttentionStore(attn_res) self.register_attention_control() # default config for step size from original repo scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) step_size = scale_factor * np.sqrt(scale_range) text_embeddings = ( prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds ) if isinstance(token_indices[0], int): token_indices = [token_indices] indices = [] for ind in token_indices: indices = indices + [ind] * num_images_per_prompt # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Attend and excite process with torch.enable_grad(): latents = latents.clone().detach().requires_grad_(True) updated_latents = [] for latent, index, text_embedding in zip(latents, indices, text_embeddings): # Forward pass of denoising with text conditioning latent = latent.unsqueeze(0) text_embedding = text_embedding.unsqueeze(0) self.unet( latent, t, encoder_hidden_states=text_embedding, cross_attention_kwargs=cross_attention_kwargs, ).sample self.unet.zero_grad() # Get max activation value for each subject token max_attention_per_index = self._aggregate_and_get_max_attention_per_token( indices=index, ) loss = self._compute_loss(max_attention_per_index=max_attention_per_index) # If this is an iterative refinement step, verify we have reached the desired threshold for all if i in thresholds.keys() and loss > 1.0 - thresholds[i]: loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( latents=latent, indices=index, loss=loss, threshold=thresholds[i], text_embeddings=text_embedding, step_size=step_size[i], t=t, ) # Perform gradient update if i < max_iter_to_alter: if loss != 0: latent = self._update_latent( latents=latent, loss=loss, step_size=step_size[i], ) logger.info(f"Iteration {i} | Loss: {loss:0.4f}") updated_latents.append(latent) latents = torch.cat(updated_latents, dim=0) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, 
class GaussianSmoothing(torch.nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the
    input using a depthwise convolution.

    Arguments:
        channels (int): Number of channels of the input tensors. Output will have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (int, float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial).

    Raises:
        RuntimeError: If `dim` is not 1, 2 or 3.
    """

    def __init__(
        self,
        channels: int = 1,
        kernel_size: int = 3,
        sigma: float = 0.5,
        dim: int = 2,
    ):
        super().__init__()

        # Reject unsupported dimensionalities before any kernel is built.
        convolutions = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
        if dim not in convolutions:
            raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))

        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * dim
        # Expand any real scalar. Testing for `float` only left an integer sigma (e.g. `sigma=1`)
        # unexpanded, which then failed inside `zip` below.
        if isinstance(sigma, (int, float)):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # NOTE(review): the exponent is ((x - mean) / (2 * std)) ** 2, which is not the textbook
            # exp(-((x - mean) / std) ** 2 / 2). It is kept unchanged so existing smoothing results
            # stay the same; the kernel is normalised below either way.
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight: (channels, 1, *kernel_size)
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer("weight", kernel)
        self.groups = channels
        self.conv = convolutions[dim]

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.

        Returns:
            filtered (torch.Tensor): Filtered output. No padding is applied here, so each spatial
            axis shrinks by `kernel_size - 1`.
        """
        # Cast the kernel to the input dtype so mixed-precision inputs are accepted.
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE: This file is deprecated and will be removed in a future version. # It only exists so that temporarily `from diffusers.pipelines import DiffusionPipeline` works from ...utils import deprecate from ..controlnet.multicontrolnet import MultiControlNetModel # noqa: F401 from ..controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline # noqa: F401 deprecate( "stable diffusion controlnet", "0.22.0", "Importing `StableDiffusionControlNetPipeline` or `MultiControlNetModel` from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import StableDiffusionControlNetPipeline` instead.", standard_warn=False, stacklevel=3, ) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    """Convert an image input to a tensor (deprecated; always emits a `FutureWarning`).

    A tensor is returned untouched. A single PIL image is treated as a one-element list. A list
    of PIL images is resized to the first image's size rounded down to a multiple of 8, stacked,
    scaled from [0, 255] to [-1, 1], and reordered with `transpose(0, 3, 1, 2)` before being
    wrapped in a tensor. A list of tensors is concatenated along dim 0. Any other list is
    returned as is.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # resize to integer multiple of 8
        width -= width % 8
        height -= height % 8
        target_size = (width, height)
        frames = [
            np.array(frame.resize(target_size, resample=PIL_INTERPOLATION["lanczos"]))[None, :] for frame in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        return torch.from_numpy(2.0 * batch - 1.0)
    if isinstance(first, torch.Tensor):
        return torch.cat(image, dim=0)
    return image
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, depth_estimator: DPTForDepthEstimation, feature_extractor: DPTFeatureExtractor, ): super().__init__() is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Device on which the pipeline's models will be executed.

    When `self.unet` carries no `_hf_hook`, this is simply `self.device`. Otherwise (e.g. after
    `pipeline.enable_sequential_cpu_offload()`) the device is taken from the first unet submodule whose
    `_hf_hook` exposes a non-`None` `execution_device`, falling back to `self.device` when none does.
    """
    if not hasattr(self.unet, "_hf_hook"):
        return self.device

    for submodule in self.unet.modules():
        hook = getattr(submodule, "_hf_hook", None)
        hook_device = getattr(hook, "execution_device", None)
        if hook_device is not None:
            return torch.device(hook_device)

    return self.device
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over `image`.

    Returns `(image, has_nsfw_concept)`. When `self.safety_checker` is `None`
    the image is returned untouched and `has_nsfw_concept` is `None`.
    """
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        # The feature extractor is fed PIL images, so convert from whichever
        # representation (tensor or numpy) we were handed.
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """Decode `latents` with the VAE into a channels-last float32 numpy array in [0, 1].

    Deprecated: emits a `FutureWarning` on every call.
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Only the keys (`"eta"`, `"generator"`) that appear in the signature of
    `self.scheduler.step` are included in the returned dict.
    """
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]

    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """Validate the user-facing call arguments.

    Raises:
        ValueError: if `strength` is outside [0, 1]; if `callback_steps` is not
            a positive `int`; if both or neither of `prompt` / `prompt_embeds`
            are given; if `prompt` is neither `str` nor `list`; if both
            `negative_prompt` and `negative_prompt_embeds` are given; or if
            `prompt_embeds` and `negative_prompt_embeds` differ in shape.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if (callback_steps is None) or (
        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
    ):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """Return the tail of the scheduler's timesteps selected by `strength`.

    Returns `(timesteps, num_steps)`, where `timesteps` is
    `self.scheduler.timesteps` with the first `t_start * self.scheduler.order`
    entries dropped and `num_steps` is `num_inference_steps - t_start`.
    `device` is accepted but not used here.
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

    return timesteps, num_inference_steps - t_start


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
    """Turn `image` into noised initial latents for the denoising loop.

    A 4-channel `image` is used as latents directly; anything else is encoded
    with the VAE and scaled by `self.vae.config.scaling_factor`. The latents
    are duplicated (with a deprecation notice) when the effective batch size
    is a multiple of the image batch, then noised via
    `self.scheduler.add_noise` at `timestep`.

    Raises:
        ValueError: on an unsupported `image` type, a generator list whose
            length differs from the effective batch size, or an image batch
            that cannot be evenly duplicated to the effective batch size.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    # NOTE(review): `.to` only exists on tensors, so despite the type check above
    # a PIL image or list would fail here; callers appear to preprocess first.
    image = image.to(device=device, dtype=dtype)

    # effective batch size: one latent per prompt per requested image
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        init_latents = image

    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        elif isinstance(generator, list):
            # one generator per sample: encode and sample each image separately
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

    if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
        # expand init_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // init_latents.shape[0]
        init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        init_latents = torch.cat([init_latents], dim=0)

    shape = init_latents.shape
    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    # get latents
    init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
    latents = init_latents

    return latents


def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
    """Produce the depth conditioning tensor at latent resolution.

    The depth map is either taken from `depth_map` or predicted from `image`
    with `self.depth_estimator`, resized (bicubic) to
    `(height // self.vae_scale_factor, width // self.vae_scale_factor)`,
    rescaled per sample to [-1, 1], repeated up to `batch_size`, and doubled
    along the batch axis when `do_classifier_free_guidance` is set.

    Args:
        image: a single PIL image, or a batch/sequence of PIL images, numpy
            arrays (channels-last) or tensors (spatial dims last); only used
            for its spatial size unless `depth_map` is `None`.
        depth_map: optional precomputed depth tensor; a channel axis is
            inserted at dim 1 before resizing.
        batch_size: effective batch size the result is repeated to.
        do_classifier_free_guidance: whether to duplicate the result.
        dtype: dtype of the returned tensor.
        device: device the computation runs on.

    NOTE: a sample whose depth is constant has `depth_max == depth_min`, so
    the rescaling divides by zero for that sample.
    """
    # A lone PIL image is the only non-batched input; everything else is
    # split along its first axis. (Same outcome as "PIL -> wrap, else list()".)
    if isinstance(image, (list, tuple, torch.Tensor, np.ndarray)):
        image = list(image)
    elif isinstance(image, PIL.Image.Image):
        image = [image]
    else:
        image = list(image)

    sample = image[0]
    if isinstance(sample, torch.Tensor):
        height, width = sample.shape[-2:]
    elif isinstance(sample, np.ndarray):
        # numpy images are channels-last, so shape[:-1] is (height, width);
        # unpacking it as (width, height) transposed non-square inputs.
        height, width = sample.shape[:-1]
    elif isinstance(sample, PIL.Image.Image):
        # PIL reports size as (width, height)
        width, height = sample.size
    else:
        height, width = sample.shape[-2:]

    if depth_map is None:
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=device)
        # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
        # So we use `torch.autocast` here for half precision inference.
        autocast_ctx = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
        with autocast_ctx:
            depth_map = self.depth_estimator(pixel_values).predicted_depth
    else:
        depth_map = depth_map.to(device=device, dtype=dtype)

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
        mode="bicubic",
        align_corners=False,
    )

    # rescale every sample independently to [-1, 1]
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
    depth_map = depth_map.to(dtype)

    # duplicate the depth map for each generation per prompt, using mps friendly method
    if depth_map.shape[0] < batch_size:
        repeat_by = batch_size // depth_map.shape[0]
        depth_map = depth_map.repeat(repeat_by, 1, 1, 1)

    depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
    return depth_map
Union[str, List[str]] = None, image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, depth_map: Optional[torch.FloatTensor] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. Can accept image latents as `image` only if `depth_map` is not `None`. depth_map (`torch.FloatTensor`, *optional*): depth prediction that will be used as additional conditioning for the image generation process. If not defined, it will automatically predicts the depth via `self.depth_estimator`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. 
When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: ```py >>> import torch >>> import requests >>> from PIL import Image >>> from diffusers import StableDiffusionDepth2ImgPipeline >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-depth", ... torch_dtype=torch.float16, ... 
) >>> pipe.to("cuda") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> init_image = Image.open(requests.get(url, stream=True).raw) >>> prompt = "two tigers" >>> n_propmt = "bad, deformed, ugly, bad anotomy" >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_propmt, strength=0.7).images[0] ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs( prompt, strength, callback_steps, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. 
Prepare depth mask depth_mask = self.prepare_depth_map( image, depth_map, batch_size * num_images_per_prompt, do_classifier_free_guidance, prompt_embeds.dtype, device, ) # 5. Preprocess image image = self.image_processor.preprocess(image) # 6. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 7. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if 
callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py ================================================ # Copyright 2023 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, BaseOutput, deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class DiffEditInversionPipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. Args: latents (`torch.FloatTensor`) inverted latents tensor images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. """ latents: torch.FloatTensor images: Union[List[PIL.Image.Image], np.ndarray] EXAMPLE_DOC_STRING = """ ```py >>> import PIL >>> import requests >>> import torch >>> from io import BytesIO >>> from diffusers import StableDiffusionDiffEditPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" >>> init_image = download_image(img_url).resize((768, 768)) >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 ... 
) >>> pipe = pipe.to("cuda") >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) >>> pipeline.enable_model_cpu_offload() >>> mask_prompt = "A bowl of fruits" >>> prompt = "A bowl of pears" >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] ``` """ EXAMPLE_INVERT_DOC_STRING = """ ```py >>> import PIL >>> import requests >>> import torch >>> from io import BytesIO >>> from diffusers import StableDiffusionDiffEditPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" >>> init_image = download_image(img_url).resize((768, 768)) >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 ... 
def auto_corr_loss(hidden_states, generator=None):
    """Accumulate a squared auto-correlation penalty over `hidden_states`.

    For every (dim 0, dim 1) slice, the slice is multiplied with a copy of
    itself rolled by a random amount along dim 2 and along dim 3; the squared
    means of those products are added up. The slice is then 2x average-pooled
    and the step repeated until its dim-2 size is at most 8.

    Args:
        hidden_states: tensor indexed as `[i, j, :, :]`.
        generator: optional `torch.Generator` used to draw the roll amounts.

    Returns:
        The accumulated loss (the float `0.0` if there are no slices).
    """
    reg_loss = 0.0
    for i in range(hidden_states.shape[0]):
        for j in range(hidden_states.shape[1]):
            noise = hidden_states[i : i + 1, j : j + 1, :, :]
            while True:
                # same random shift is used for both spatial dims at this scale
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2

                if noise.shape[2] <= 8:
                    break
                noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2)
    return reg_loss


def kl_divergence(hidden_states):
    """Return `var + mean**2 - 1 - log(var + 1e-7)` over all elements of `hidden_states`.

    Presumably a (scaled) KL term against a standard normal -- the formula
    is zero for unit variance and zero mean.
    """
    variance = hidden_states.var()  # computed once; used in both terms
    return variance + hidden_states.mean() ** 2 - 1 - torch.log(variance + 1e-7)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    """Convert PIL image(s) or a list of tensors into one batched tensor.

    Deprecated: emits a `FutureWarning` on every call. A tensor input is
    returned unchanged. PIL images are resized down to a multiple of 8 per
    side, scaled to [-1, 1] and returned channels-first; a list of tensors is
    concatenated along dim 0.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


def preprocess_mask(mask, batch_size: int = 1):
    """Normalize a mask input into a binarized 4-D tensor with one channel.

    Args:
        mask: a `torch.Tensor`, PIL image, numpy array, or a list of those.
            PIL images are converted to grayscale and scaled to [0, 1].
        batch_size: batch size to broadcast a single mask to.

    Returns:
        A new tensor of shape `(batch, 1, H, W)` whose values are 0 where the
        input was below 0.5 and 1 otherwise. The caller's tensor is never
        modified.

    Raises:
        ValueError: if the mask batch cannot be broadcast to `batch_size`,
            if it does not have exactly one channel, or if any value lies
            outside [0, 1].
    """
    if not isinstance(mask, torch.Tensor):
        # preprocess mask
        if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
            mask = [mask]

        if isinstance(mask, list):
            if isinstance(mask[0], PIL.Image.Image):
                mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask]
            if isinstance(mask[0], np.ndarray):
                mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0)
                mask = torch.from_numpy(mask)
            elif isinstance(mask[0], torch.Tensor):
                mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0)

    # Batch and add channel dim for single mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    # Batch single mask or add channel dim
    if mask.ndim == 3:
        # Single batched mask, no channel dim or single mask not batched but channel dim
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)

        # Batched masks no channel dim
        else:
            mask = mask.unsqueeze(1)

    # Check mask shape
    if batch_size > 1:
        if mask.shape[0] == 1:
            mask = torch.cat([mask] * batch_size)
        elif mask.shape[0] > 1 and mask.shape[0] != batch_size:
            raise ValueError(
                f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} "
                f"inferred by prompt inputs"
            )

    if mask.shape[1] != 1:
        raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels")

    # Check mask is in [0, 1]
    if mask.min() < 0 or mask.max() > 1:
        raise ValueError("`mask_image` should be in [0, 1] range")

    # Binarize into a fresh tensor: `mask` may still be (a view of) the
    # caller's tensor, and indexed assignment would overwrite their data.
    mask = (mask >= 0.5).to(mask.dtype)

    return mask
inverse_scheduler (`[DDIMInverseScheduler]`): A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, inverse_scheduler: DDIMInverseScheduler, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration" " `skip_prk_steps`. 
`skip_prk_steps` should be set to True in the configuration file. Please make" " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" " Hub, it would be very nice if you could open a Pull request for the" " `scheduler/scheduler_config.json` file" ) deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, inverse_scheduler=inverse_scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Hand the pipeline's models over to accelerate's `cpu_offload`.

    The pipeline is first moved to CPU (unless it is already there), then the unet, the text encoder and the
    vae are registered with `cpu_offload` using `cuda:{gpu_id}` as execution device. The safety checker, when
    one is set, is registered as well, with its buffers offloaded too.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index of the CUDA device used as execution device.

    Raises:
        ImportError: if accelerate is not installed or is older than 0.14.0.
    """
    accelerate_is_usable = is_accelerate_available() and is_accelerate_version(">=", "0.14.0")
    if not accelerate_is_usable:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    pipeline_on_cpu = self.device.type == "cpu"
    if not pipeline_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for component in (self.unet, self.text_encoder, self.vae):
        cpu_offload(component, execution_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=execution_device, offload_buffers=True)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments understood by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler's `step` declares a parameter of that name.

    Args:
        generator: value to forward as `generator` when supported.
        eta: value to forward as `eta` when supported. eta (η) corresponds to η in the DDIM paper
            (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: contains `"eta"` and/or `"generator"` (in that order) for the supported names only.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
def check_inputs(
    self,
    prompt,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the prompt-related call arguments before any model is run.

    Args:
        prompt: the text prompt; must be a `str` or a `list` when given.
        strength: must be a number in the closed interval [0.0, 1.0].
        callback_steps: must be a positive `int`.
        negative_prompt: optional negative prompt; mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: optional pre-computed embeddings; mutually exclusive with `prompt`.
        negative_prompt_embeds: optional pre-computed negative embeddings; when passed together with
            `prompt_embeds`, both must have the same `.shape`.

    Raises:
        ValueError: if any of the constraints above is violated.
    """
    # Negated chained comparison: rejects values outside [0, 1] and also NaN, for which both
    # `strength < 0` and `strength > 1` are False.
    if strength is None or not (0 <= strength <= 1):
        raise ValueError(
            f"The value of `strength` should be in [0.0, 1.0] but is {strength} of type {type(strength)}."
        )

    # `None` is not an `int`, so the isinstance test covers the missing-value case as well.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
def get_inverse_timesteps(self, num_inference_steps, strength, device):
    """
    Select the leading part of the inverse scheduler's timesteps that `strength` asks for.

    Args:
        num_inference_steps: total number of steps the schedule was built with.
        strength: fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept,
            capped at `num_inference_steps`.
        device: unused by this method.

    Returns:
        `tuple`: the selected timesteps and the number of steps they represent.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    steps_to_drop = max(num_inference_steps - steps_to_run, 0)
    schedule = self.inverse_scheduler.timesteps

    # A `[:-0]` slice would be empty, so the nothing-to-drop case returns the schedule untouched.
    if steps_to_drop == 0:
        return schedule, num_inference_steps

    return schedule[:-steps_to_drop], num_inference_steps - steps_to_drop
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) if image.shape[1] == 4: latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if isinstance(generator, list): latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] latents = torch.cat(latents, dim=0) else: latents = self.vae.encode(image).latent_dist.sample(generator) latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: # expand image_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
@torch.no_grad()
def generate_mask(
    self,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    target_prompt: Optional[Union[str, List[str]]] = None,
    target_negative_prompt: Optional[Union[str, List[str]]] = None,
    target_prompt_embeds: Optional[torch.FloatTensor] = None,
    target_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    source_prompt: Optional[Union[str, List[str]]] = None,
    source_negative_prompt: Optional[Union[str, List[str]]] = None,
    source_prompt_embeds: Optional[torch.FloatTensor] = None,
    source_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    num_maps_per_mask: Optional[int] = 10,
    mask_encode_strength: Optional[float] = 0.5,
    mask_thresholding_ratio: Optional[float] = 3.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: Optional[str] = "np",
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function used to generate a latent mask given a source prompt, a target prompt, and an image.

    The image latents are noised `num_maps_per_mask` times, the unet predicts the noise once under the source
    prompt and once under the target prompt, and the mask is obtained by thresholding the averaged absolute
    difference of the two predictions.

    Args:
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be used for computing the mask.
        target_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass
            `target_prompt_embeds` instead.
        target_negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `target_negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is less than `1`).
        target_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `target_prompt` input argument.
        target_negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, they will be generated from the `target_negative_prompt` input
            argument.
        source_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit:
            Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf).
            If not defined, one has to pass `source_prompt_embeds` instead.
        source_negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the semantic mask generation away from using the method in
            [DiffEdit: Diffusion-Based Semantic Image Editing with Mask
            Guidance](https://arxiv.org/pdf/2210.11427.pdf). If not defined, one has to pass
            `source_negative_prompt_embeds` instead.
        source_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak
            text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from
            `source_prompt` input argument.
        source_negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to
            easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be
            generated from `source_negative_prompt` input argument.
        num_maps_per_mask (`int`, *optional*, defaults to 10):
            The number of noise maps sampled to generate the semantic mask using the method in [DiffEdit:
            Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf).
        mask_encode_strength (`float`, *optional*, defaults to 0.5):
            Conceptually, the strength of the noise maps sampled to generate the semantic mask using the method
            in [DiffEdit: Diffusion-Based Semantic Image Editing with Mask
            Guidance](https://arxiv.org/pdf/2210.11427.pdf). Must be between 0 and 1.
        mask_thresholding_ratio (`float`, *optional*, defaults to 3.0):
            The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before
            mask binarization.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated mask. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also forwarded to the prompt encoder as the LoRA scale.

    Examples:

    Returns:
        `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a
        `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of single-channel
        binary image with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, otherwise
        the `np.array` will have shape `(batch_size, height // self.vae_scale_factor, width //
        self.vae_scale_factor)`.

    Raises:
        ValueError: if the prompt arguments are inconsistent, or if `num_maps_per_mask` is not a positive
            integer, or if `mask_thresholding_ratio` is not positive.
    """

    # 1. Check inputs (Provide dummy argument for callback_steps)
    self.check_inputs(
        target_prompt,
        mask_encode_strength,
        1,
        target_negative_prompt,
        target_prompt_embeds,
        target_negative_prompt_embeds,
    )

    self.check_source_inputs(
        source_prompt,
        source_negative_prompt,
        source_prompt_embeds,
        source_negative_prompt_embeds,
    )

    if (num_maps_per_mask is None) or (
        num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0)
    ):
        raise ValueError(
            f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type"
            f" {type(num_maps_per_mask)}."
        )

    if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0:
        raise ValueError(
            f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type"
            f" {type(mask_thresholding_ratio)}."
        )

    # 2. Define call parameters
    if target_prompt is not None and isinstance(target_prompt, str):
        batch_size = 1
    elif target_prompt is not None and isinstance(target_prompt, list):
        batch_size = len(target_prompt)
    else:
        batch_size = target_prompt_embeds.shape[0]
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompts
    # The LoRA scale used to be looked up and then discarded; bind it and hand it to the prompt
    # encoder so that text-encoder LoRA layers see the same scale as the unet's attention processors.
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None)
    target_prompt_embeds = self._encode_prompt(
        target_prompt,
        device,
        num_maps_per_mask,
        do_classifier_free_guidance,
        target_negative_prompt,
        prompt_embeds=target_prompt_embeds,
        negative_prompt_embeds=target_negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    source_prompt_embeds = self._encode_prompt(
        source_prompt,
        device,
        num_maps_per_mask,
        do_classifier_free_guidance,
        source_negative_prompt,
        prompt_embeds=source_prompt_embeds,
        negative_prompt_embeds=source_negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess image
    image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)

    # 5. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device)
    encode_timestep = timesteps[0]

    # 6. Prepare image latents and add noise with specified strength
    image_latents = self.prepare_image_latents(
        image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator
    )
    noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype)
    image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep)

    # One copy of the latents per embedding group: source/target, doubled when guidance adds the
    # negative embeddings in front of each.
    latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2))
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep)

    # 7. Predict the noise residual
    prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds])
    noise_pred = self.unet(
        latent_model_input,
        encode_timestep,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
    ).sample

    if do_classifier_free_guidance:
        noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4)
        noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src)
        noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond)
    else:
        noise_pred_source, noise_pred_target = noise_pred.chunk(2)

    # 8. Compute the mask from the absolute difference of predicted noise residuals
    # TODO: Consider smoothing mask guidance map
    # The reshape splits the leading dimension into (batch_size, num_maps_per_mask); the mean then
    # collapses the per-map dimension and the one after it.
    mask_guidance_map = (
        torch.abs(noise_pred_target - noise_pred_source)
        .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:])
        .mean([1, 2])
    )
    clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio
    semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude
    semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1)
    mask_image = semantic_mask_image.cpu().numpy()

    # 9. Convert to Numpy array or PIL.
    if output_type == "pil":
        mask_image = self.image_processor.numpy_to_pil(mask_image)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return mask_image
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. decode_latents (`bool`, *optional*, defaults to `False`): Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` will decode all inverted latents for each timestep into a list of generated images. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). lambda_auto_corr (`float`, *optional*, defaults to 20.0): Lambda parameter to control auto correction lambda_kl (`float`, *optional*, defaults to 20.0): Lambda parameter to control Kullback–Leibler divergence output num_reg_steps (`int`, *optional*, defaults to 0): Number of regularization loss steps num_auto_corr_rolls (`int`, *optional*, defaults to 5): Number of auto correction roll steps Examples: Returns: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted latents tensors ordered by increasing noise, and then second is the corresponding decoded images if `decode_latents` is `True`, otherwise `None`. """ # 1. 
Check inputs self.check_inputs( prompt, inpaint_strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if cross_attention_kwargs is None: cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image image = preprocess(image) # 4. Prepare latent variables num_images_per_prompt = 1 latents = self.prepare_image_latents( image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator ) # 5. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 6. Prepare timesteps self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. 
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order inverted_latents = [latents.detach().clone()] with self.progress_bar(total=num_inference_steps - 1) as progress_bar: for i, t in enumerate(timesteps[:-1]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) if num_reg_steps > 0: with torch.enable_grad(): for _ in range(num_reg_steps): if lambda_auto_corr > 0: for _ in range(num_auto_corr_rolls): var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) l_ac = auto_corr_loss(var_epsilon, generator=generator) l_ac.backward() grad = var.grad.detach() / num_auto_corr_rolls noise_pred = noise_pred - lambda_auto_corr * grad if lambda_kl > 0: var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) l_kld = kl_divergence(var_epsilon) l_kld.backward() grad = var.grad.detach() noise_pred = noise_pred - lambda_kl * grad noise_pred = noise_pred.detach() # compute the previous noisy sample x_t -> x_t-1 latents = self.inverse_scheduler.step(noise_pred, t, 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    image_latents: torch.FloatTensor = None,
    inpaint_strength: Optional[float] = 0.8,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask
            will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be
            converted to a single channel (luminance) before use. If it's a tensor, it should contain one color
            channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`. Must not be `None`; use
            `generate_mask()` to compute it from text prompts.
        image_latents (`PIL.Image.Image` or `torch.FloatTensor`):
            Partially noised image latents from the inversion process to be used as inputs for image generation.
            Must not be `None`; use `invert()` to compute them. Each latent must have shape
            `(vae.config.latent_channels, H, W)` where `H, W` are the last two dimensions of the preprocessed
            mask, and there must be one latent per retained timestep for every batch element (a 4-D input is
            reshaped to `(batch_size, len(timesteps), ...)`).
        inpaint_strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When
            `inpaint_strength` is 1, the denoising process will be run on the masked area for the full number of
            iterations specified in `num_inference_steps`. `image_latents` will be used as a reference for the
            masked area, adding more noise to that region the larger the `inpaint_strength`. If
            `inpaint_strength` is 0, no inpainting will occur.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Currently not read by this method: the denoising loop always starts from `image_latents[0]`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            VAE decode and the safety checker are skipped and the final latents are post-processed instead.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also forwarded to `_encode_prompt` as `lora_scale`.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If `mask_image` or `image_latents` is `None`, or if `image_latents` does not match the
            expected per-latent shape or the `(batch_size, len(timesteps))` layout.
    """

    # 1. Check inputs
    self.check_inputs(
        prompt,
        inpaint_strength,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    if mask_image is None:
        raise ValueError(
            "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts."
        )
    if image_latents is None:
        raise ValueError(
            "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images."
        )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    # `cross_attention_kwargs` was defaulted to `{}` above, so it can be queried directly.
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None)
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess mask
    mask_image = preprocess_mask(mask_image, batch_size)
    latent_height, latent_width = mask_image.shape[-2:]
    # NOTE(review): the mask is tiled with `torch.cat` while `image_latents` below is expanded with
    # `repeat_interleave`; for batch_size > 1 combined with num_images_per_prompt > 1 these two
    # orderings look different -- confirm against `preprocess_mask` before relying on that case.
    mask_image = torch.cat([mask_image] * num_images_per_prompt)
    mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype)

    # 5. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device)

    # 6. Preprocess image latents
    image_latents = preprocess(image_latents)
    latent_shape = (self.vae.config.latent_channels, latent_height, latent_width)
    if image_latents.shape[-3:] != latent_shape:
        raise ValueError(
            f"Each latent image in `image_latents` must have shape {latent_shape}, "
            f"but has shape {image_latents.shape[-3:]}"
        )
    if image_latents.ndim == 4:
        image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape)
    if image_latents.shape[:2] != (batch_size, len(timesteps)):
        raise ValueError(
            f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, "
            f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps."
        )
    # Put the timestep axis first so that `image_latents[i]` is the batch of latents for step `i`.
    image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1)
    image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype)

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    latents = image_latents[0].detach().clone()
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # keep the denoised result where the mask is set and restore the inverted
            # latents stored at index `i` everywhere else
            latents = latents * mask_image + image_latents[i] * (1 - mask_image)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized during post-processing.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Callable, List, Optional, Union import PIL import torch from packaging import version from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name class StableDiffusionImageVariationPipeline(DiffusionPipeline): r""" Pipeline to generate variations from an input image using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ # TODO: feature_extractor is required to encode images (if they are in PIL format), # we should give a descriptive message if the pipeline doesn't have one. 
_optional_components = ["safety_checker"]

def __init__(
    self,
    vae: AutoencoderKL,
    image_encoder: CLIPVisionModelWithProjection,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the VAE scale factor.

    Args:
        vae: Autoencoder registered as `vae`; its `config.block_out_channels` determines `vae_scale_factor`.
        image_encoder: Model registered as `image_encoder`.
        unet: Model registered as `unet`. Its config is patched to `sample_size=64` when it reports a
            diffusers version below `0.9.0.dev0` together with a `sample_size` smaller than 64.
        scheduler: Scheduler registered as `scheduler`.
        safety_checker: Safety checker, or `None` to disable it (a warning is logged when
            `requires_safety_checker` is `True`).
        feature_extractor: Feature extractor; required whenever `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config; controls whether disabling the safety
            checker triggers the warning.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        # `Logger.warn` is a deprecated alias of `Logger.warning`.
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first literal must be an f-string, otherwise `{self.__class__}` is emitted verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        # Overwrite the frozen config with a copy that carries the corrected sample size.
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's sub-models to CPU through accelerate's `cpu_offload`.

    The unet, image encoder, vae and safety checker are each handed to `cpu_offload` together with the
    `cuda:{gpu_id}` device; components that are `None` are skipped.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device passed to `cpu_offload`.

    Raises:
        ImportError: If accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    offload_candidates = (self.unet, self.image_encoder, self.vae, self.safety_checker)
    for component in offload_candidates:
        if component is None:
            continue
        cpu_offload(component, target_device)
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
@property
def _execution_device(self):
    r"""
    Device on which the pipeline's models are executed.

    When the unet carries no `_hf_hook` attribute this is simply `self.device`. Otherwise the unet's
    sub-modules are scanned and the first hook exposing a non-`None` `execution_device` decides the
    result; `self.device` is the fallback when no such hook is found.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            hook_device = getattr(hook, "execution_device", None)
            if hook_device is not None:
                return torch.device(hook_device)
    return self.device


def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
    """
    Encode `image` into the embeddings used to condition the unet.

    Non-tensor inputs are first converted to pixel values by `self.feature_extractor`. The image
    embeddings get a singleton sequence axis and are duplicated `num_images_per_prompt` times per input.
    With classifier free guidance, an all-zero block of the same shape is prepended along the batch axis.
    """
    encoder_dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

    pixel_values = image.to(device=device, dtype=encoder_dtype)
    embeds = self.image_encoder(pixel_values).image_embeds.unsqueeze(1)

    # duplicate image embeddings for each generation per prompt, using mps friendly method
    batch, seq_len, _ = embeds.shape
    embeds = embeds.repeat(1, num_images_per_prompt, 1)
    embeds = embeds.view(batch * num_images_per_prompt, seq_len, -1)

    if not do_classifier_free_guidance:
        return embeds

    # A single batch holding [unconditional, conditional] embeddings lets the unet
    # serve both guidance branches in one forward pass.
    unconditional = torch.zeros_like(embeds)
    return torch.cat([unconditional, embeds])
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on `image`.

    Returns a `(image, has_nsfw_concept)` tuple. Without a safety checker the image is returned as
    given and the flag is `None`; otherwise both values come from `self.safety_checker`.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, built from either a tensor or a numpy batch.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags


# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Decode `latents` with the VAE into a float32 numpy batch laid out channels-last.

    Emits a `FutureWarning` because the method is deprecated. Decoded values are mapped with
    `x / 2 + 0.5` and clamped to `[0, 1]`.
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents, return_dict=False)[0]
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Scheduler `step` signatures differ, so `eta` and `generator` are forwarded only when the current
    scheduler's `step` declares a parameter of that name.
    eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 and should be between [0, 1].
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs


def check_inputs(self, image, height, width, callback_steps):
    """
    Validate the arguments of the pipeline call.

    Raises:
        ValueError: If `image` is not a tensor, a PIL image or a list; if `height` or `width` is not
            divisible by 8; or if `callback_steps` is not a positive integer.
    """
    image_type_ok = (
        isinstance(image, torch.Tensor) or isinstance(image, PIL.Image.Image) or isinstance(image, list)
    )
    if not image_type_ok:
        raise ValueError(
            "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
            f" {type(image)}"
        )

    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an `int`, so the isinstance test also rejects a missing value.
    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the starting latents for the denoising loop.

    Provided `latents` are moved to `device`; otherwise new ones are drawn with `randn_tensor` using
    the latent shape `(batch_size, num_channels_latents, height // vae_scale_factor,
    width // vae_scale_factor)`. In both cases the result is multiplied by the scheduler's
    `init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // downscale, width // downscale)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
If you provide a tensor, it needs to comply with the configuration of
                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                # The progress bar only advances on the last timestep, or past the warmup steps on every
                # `self.scheduler.order`-th iteration; the callback is nested under the same condition.
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Unless raw latents were requested, decode them with the VAE and run the safety checker on the result.
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
        else:
            image = latents
has_nsfw_concept = None

        # Build the per-image `do_denormalize` flags: all True when there is no NSFW verdict, otherwise True only
        # for the images that were not flagged.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


================================================
FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    PIL_INTERPOLATION,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# NOTE(review): this literal appears whitespace-collapsed in this dump; it is a runtime string (passed to
# `replace_example_docstring` further down), so it is kept exactly as found — confirm against the real file.
EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableDiffusionImg2ImgPipeline >>> device = "cuda" >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) >>> pipe = pipe.to(device) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) >>> prompt = "A fantasy landscape, trending on artstation" >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images >>> images[0].save("fantasy_landscape.png") ``` """


def preprocess(image):
    """
    Deprecated helper that converts the supported image inputs into a single `torch.Tensor`.

    A `torch.Tensor` is returned unchanged. A single `PIL.Image.Image` is treated as a one-element list. A list of
    PIL images is resized (width and height taken from the first image, rounded down to a multiple of 8), stacked,
    divided by 255, transposed with `(0, 3, 1, 2)`, mapped through `2.0 * x - 1.0` and wrapped with
    `torch.from_numpy`. A list of tensors is concatenated along dimension 0.

    Always emits a `FutureWarning`.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        # Every image is resized to the (rounded) size of the first one.
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


class StableDiffusionImg2ImgPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args:
             prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        # The batch size comes from `prompt` when it is a str or list, otherwise from `prompt_embeds`.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # Tokenize a second time without truncation, only to detect and report text that was cut off.
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # Pad the unconditional tokens to the sequence length of the prompt embeddings.
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        """
        Run the optional safety checker on `image`.

        Returns:
            `tuple`: `(image, has_nsfw_concept)`. When `self.safety_checker` is `None`, `image` is returned
            untouched and `has_nsfw_concept` is `None`; otherwise both values come from `self.safety_checker`.
        """
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # The feature extractor is fed PIL images, converted from either a tensor or a numpy array.
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        """
        Deprecated: decode `latents` with the VAE and return the result as a float32 numpy array.

        The decoded tensor is rescaled with `x / 2 + 0.5`, clamped to [0, 1], moved to CPU and permuted with
        `(0, 2, 3, 1)`. Always emits a `FutureWarning`.
        """
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) elif isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass
                `prompt_embeds` instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. Can also accept image latents as `image`; if latents are passed directly, they will not be
                encoded again.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter will be modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is not greater than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
                VAE decoding and the safety checker are skipped and the denoised latents are handed to the image
                processor instead.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
                Its `"scale"` entry, when present, is also forwarded to `_encode_prompt` as the LoRA scale.
        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        # The batch size follows the prompt when one is given, otherwise the pre-computed embeddings.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        # `get_timesteps` also returns the step count that replaces `num_inference_steps` from here on
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        # the first remaining timestep, repeated once per generated image
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    # the first half of the batch is the unconditional prediction, the second half the
                    # text-conditioned one (matches the `[latents] * 2` concatenation above)
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                # the progress bar (and therefore the callback) only advances on the last step or, after the
                # warmup steps, on every `scheduler.order`-th step
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            # undo the latent scaling before decoding, then screen the decoded images
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # images flagged by the safety checker are not denormalized
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
    """
    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
    ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (values ``>= 0.5`` become ``1``, all others ``0``). Masks given as ``PIL.Image`` are converted to
    ``torch.float32``; masks given as ``torch.Tensor`` or ``np.array`` keep their dtype.

    The caller's ``mask`` is never modified: binarization is applied to a copy.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. Lists of ``PIL.Image`` or
            ``np.array`` are accepted as a batch.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. Lists of ``PIL.Image`` or
            ``np.array`` are accepted as a batch.
        height (int): Target height. Only used to resize ``PIL.Image`` inputs.
        width (int): Target width. Only used to resize ``PIL.Image`` inputs.
        return_image (bool): When ``True`` the preprocessed image is returned as a third element.


    Raises:
        ValueError: ``image`` or ``mask`` is ``None``.
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        AssertionError: ``torch.Tensor`` ``mask`` and ``image`` should have the same spatial dimensions and batch
            size.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels
        x height x width``. When ``return_image`` is ``True`` the triple (mask, masked_image, image) is returned.
    """

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask.
        # `unsqueeze` returns a view, so at this point `mask` may still share storage with the
        # tensor the caller passed in. Work on a copy so the caller's mask is left untouched.
        mask = mask.clone()
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t passed height and width
            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # channels-last -> channels-first, then map [0, 255] onto [-1, 1]
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `np.concatenate` / `astype` above already produced a fresh array, so this in-place
        # binarization cannot touch the caller's data.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    # keep only the pixels outside the region to inpaint
    masked_image = image * (mask < 0.5)

    # n.b. ensure backwards compatibility as old function does not return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
Default text-to-image stable diffusion checkpoints, such as [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with this pipeline, but might be less performant. Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration" " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" " incorrect results in future versions. 
If you have downloaded this checkpoint from the Hugging Face" " Hub, it would be very nice if you could open a Pull request for the" " `scheduler/scheduler_config.json` file" ) deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when `do_classifier_free_guidance` is `False`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            `torch.FloatTensor`: The prompt embeddings repeated `num_images_per_prompt` times. With classifier free
            guidance the negative embeddings are concatenated in front of them along the batch dimension.

        Raises:
            TypeError: If `negative_prompt` and `prompt` are both given but have different types.
            ValueError: If a list `negative_prompt` does not match the batch size.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        # the batch size follows the prompt when one is given, otherwise the pre-computed embeddings
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # tokenize a second time without truncation, only to detect and report what was cut off
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            # only pass an attention mask when the text encoder config asks for one
            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            # keep only the first element of the text encoder output
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                # NOTE(review): this yields a single unconditional sequence regardless of `batch_size`. When
                # `prompt_embeds` with a batch size > 1 is passed together with a string `negative_prompt`, the
                # `view` further down looks like it would receive a mismatched element count — confirm.
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # pad the unconditional tokens to the sequence length of the prompt embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds
def check_inputs(
    self,
    prompt,
    height,
    width,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the pipeline call arguments and raise on the first violation.

    Checks run in this order: `strength` range, `height`/`width` divisibility by 8,
    `callback_steps`, the `prompt` / `prompt_embeds` combination, the
    `negative_prompt` / `negative_prompt_embeds` combination, and finally that both
    embedding tensors share one shape.

    Raises:
        ValueError: If any of the checks above fails.
    """
    strength_out_of_range = strength < 0 or strength > 1
    if strength_out_of_range:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None`, non-integers and non-positive integers are all rejected alike.
    callback_steps_is_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not callback_steps_is_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    if prompt is None:
        if prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
    elif prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and has_negative_embeds:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode `image` with the VAE, sample its latent distribution and scale the result.

    Args:
        image: Batch passed to `self.vae.encode`.
        generator: Either a single generator used for the whole batch, or a list that is
            indexed per batch entry; in the list case every entry of `image` is encoded
            and sampled separately and the samples are concatenated along dim 0.

    Returns:
        The sampled latents multiplied by `self.vae.config.scaling_factor`.
    """
    vae = self.vae
    if not isinstance(generator, list):
        sampled = vae.encode(image).latent_dist.sample(generator=generator)
    else:
        per_entry = []
        for idx in range(image.shape[0]):
            posterior = vae.encode(image[idx : idx + 1]).latent_dist
            per_entry.append(posterior.sample(generator=generator[idx]))
        sampled = torch.cat(per_entry, dim=0)

    return vae.config.scaling_factor * sampled
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
            be masked out with `mask_image` and repainted according to `prompt`.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        strength (`float`, *optional*, defaults to 1.):
            Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
            between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
            `strength`. The number of denoising steps depends on the amount of noise initially added. When
            `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
            iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
            portion of the reference `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    ```py
    >>> import PIL
    >>> import requests
    >>> import torch
    >>> from io import BytesIO

    >>> from diffusers import StableDiffusionInpaintPipeline


    >>> def download_image(url):
    ...     response = requests.get(url)
    ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


    >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

    >>> init_image = download_image(img_url).resize((512, 512))
    >>> mask_image = download_image(mask_url).resize((512, 512))

    >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
    ...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ... )
    >>> pipe = pipe.to("cuda")

    >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
    >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
    ```

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
        element is a list of `bool`s denoting whether the corresponding generated image likely represents
        "not-safe-for-work" (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If `check_inputs` rejects the arguments, if `strength` leaves fewer than one inference
            step, or if the unet's input channel count is neither 4 nor consistent with the 9-channel layout.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        height,
        width,
        strength,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(
        num_inference_steps=num_inference_steps, strength=strength, device=device
    )
    # check that number of inference steps is not < 1 - as this doesn't make sense
    if num_inference_steps < 1:
        # The second fragment starts with a space so the concatenated message reads "pipeline steps".
        raise ValueError(
            f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
            f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
        )
    # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
    is_strength_max = strength == 1.0

    # 5. Preprocess mask and image
    mask, masked_image, init_image = prepare_mask_and_masked_image(
        image, mask_image, height, width, return_image=True
    )

    # 6. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    num_channels_unet = self.unet.config.in_channels
    # a 4-channel unet gets no mask channels, so the image latents are needed for blending in the loop below
    return_image_latents = num_channels_unet == 4

    latents_outputs = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        image=init_image,
        timestep=latent_timestep,
        is_strength_max=is_strength_max,
        return_noise=True,
        return_image_latents=return_image_latents,
    )

    if return_image_latents:
        latents, noise, image_latents = latents_outputs
    else:
        latents, noise = latents_outputs

    # 7. Prepare mask latent variables
    mask, masked_image_latents = self.prepare_mask_latents(
        mask,
        masked_image,
        batch_size * num_images_per_prompt,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        do_classifier_free_guidance,
    )
    # NOTE(review): the encoded `init_image` is not read again in this method. The call is kept because
    # `_encode_vae_image` samples with `generator`; confirm the effect on seeded runs before removing it.
    init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
    init_image = self._encode_vae_image(init_image, generator=generator)

    # 8. Check that sizes of mask, masked image and latents match
    if num_channels_unet == 9:
        # default case for runwayml/stable-diffusion-inpainting
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )
    elif num_channels_unet != 4:
        raise ValueError(
            f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
        )

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # concat latents, mask, masked_image_latents in the channel dimension
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            if num_channels_unet == 9:
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if num_channels_unet == 4:
                # Without mask channels in the unet, the region where the mask is 0 is overwritten with the
                # (re-noised) image latents after every step.
                # NOTE(review): only the first batch entry of `image_latents` and `mask` is used and broadcast
                # over all latents; confirm this is intended for batches with differing images or masks.
                init_latents_proper = image_latents[:1]
                init_mask = mask[:1]

                if i < len(timesteps) - 1:
                    noise_timestep = timesteps[i + 1]
                    init_latents_proper = self.scheduler.add_noise(
                        init_latents_proper, noise, torch.tensor([noise_timestep])
                    )

                latents = (1 - init_mask) * init_latents_proper + init_mask * latents

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are not denormalized during postprocessing
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def preprocess_mask(mask, batch_size, scale_factor=8):
    """Convert an inpainting mask into a tensor at latent resolution.

    Args:
        mask: Either a PIL image or a 4-D tensor. Any `torch.Tensor` takes the tensor path (the
            previous `torch.FloatTensor` check only matched CPU float32 tensors, which sent other
            tensor types into the PIL branch where they failed on `.convert`). Tensor masks must
            carry a channel dimension of size 1 or 3 as either the second or the fourth dimension.
        batch_size: Number of copies the PIL mask is stacked into. Not used for tensor masks.
        scale_factor: Divisor applied to the spatial size (after rounding it down to a multiple
            of 8).

    Returns:
        For a PIL mask: a tensor with the luminance mask tiled to 4 channels, repeated
        `batch_size` times and inverted (`1 - mask`). For a tensor mask: the channel-averaged
        mask with a single channel, resized with `torch.nn.functional.interpolate`.

    Raises:
        ValueError: If a tensor mask has no channel dimension of size 1 or 3 in the second or
            fourth position.
    """
    if not isinstance(mask, torch.Tensor):
        mask = mask.convert("L")
        w, h = mask.size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
        mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
        mask = np.array(mask).astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))
        mask = np.vstack([mask[None]] * batch_size)
        mask = 1 - mask  # repaint white, keep black
        return torch.from_numpy(mask)

    # NOTE(review): unlike the PIL branch, tensor masks are neither inverted nor tiled to 4 channels here.
    valid_mask_channel_sizes = [1, 3]
    # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
    if mask.shape[3] in valid_mask_channel_sizes:
        mask = mask.permute(0, 3, 1, 2)
    elif mask.shape[1] not in valid_mask_channel_sizes:
        raise ValueError(
            f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
            f" but received mask of shape {tuple(mask.shape)}"
        )
    # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
    mask = mask.mean(dim=1, keepdim=True)
    h, w = mask.shape[-2:]
    h, w = (x - x % 8 for x in (h, w))  # resize to integer multiple of 8
    return torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and patch outdated scheduler / unet configs.

    Emits a deprecation notice for the class itself, then overrides (with a deprecation
    notice each) `steps_offset != 1` and `clip_sample is True` on the scheduler config and a
    `sample_size < 64` on unet configs saved by diffusers versions older than 0.9.0.

    Args:
        vae: Autoencoder registered as `vae`; its `block_out_channels` determine `vae_scale_factor`.
        text_encoder: Module registered as `text_encoder`.
        tokenizer: Tokenizer registered as `tokenizer`.
        unet: Module registered as `unet`.
        scheduler: Scheduler registered as `scheduler`.
        safety_checker: Optional safety checker. Passing `None` while `requires_safety_checker`
            is true only logs a warning.
        feature_extractor: Required whenever `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config.

    Raises:
        ValueError: If `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    # Adjacent string literals are concatenated verbatim, so each continuation starts with a space.
    deprecation_message = (
        f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality"
        " by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533"
        " for more information."
    )
    deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so that the class name is interpolated instead of printing `{self.__class__}` literally
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
    to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
    method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
    `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

    The hook returned for the last offloaded model is stored in `self.final_offload_hook`.

    Raises:
        ImportError: If accelerate is unavailable or older than the required version.
    """
    accelerate_is_usable = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_is_usable:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # Each hook is chained to the one of the model offloaded before it; the safety checker, when present, goes last.
    offload_order = [self.text_encoder, self.unet, self.vae]
    if self.safety_checker is not None:
        offload_order.append(self.safety_checker)

    previous_hook = None
    for model in offload_order:
        _, previous_hook = cpu_offload_with_hook(model, device, prev_module_hook=previous_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = previous_hook
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker over `image`.

    Args:
        image: decoded image batch, either a tensor or (otherwise) something `numpy_to_pil` accepts.
        device: device the feature-extractor output is moved to.
        dtype: dtype the CLIP pixel values are cast to before the checker runs.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`; the second element is `None` when no safety checker is set.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert from whichever format we were given.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_input = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=checker_input.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept


def decode_latents(self, latents):
    """
    Decode `latents` with the VAE into a float32 NumPy array (deprecated).

    The decoded tensor is mapped from `[-1, 1]` to `[0, 1]`, clamped, moved to CPU and permuted with
    `(0, 2, 3, 1)` before conversion. Emits a `FutureWarning` on every call.
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()


def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler declares a parameter of that name. eta (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: subset of `{"eta": eta, "generator": generator}` supported by the scheduler.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timestep schedule that `strength` asks to run.

    Args:
        num_inference_steps (`int`): number of steps the scheduler was configured with.
        strength (`float`): fraction of the schedule to run.
        device: not used by this method; kept so the signature stays unchanged.

    Returns:
        `tuple`: `(timesteps, steps_left)` where `timesteps` is the remaining slice of
        `self.scheduler.timesteps` and `steps_left` is `num_inference_steps - t_start`.
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

    t_start = max(num_inference_steps - init_timestep, 0)
    # the slice offset is scaled by the scheduler order (presumably because higher-order schedulers
    # store several entries per step -- confirm against the scheduler in use)
    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

    return timesteps, num_inference_steps - t_start


def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
    """
    Encode `image` with the VAE and produce the noised starting latents.

    Returns:
        `tuple`: `(latents, init_latents_orig, noise)` -- the noised latents at `timestep`, the clean scaled
        latents (used later to restore the region the mask preserves) and the noise that was added.
    """
    image = image.to(device=device, dtype=dtype)
    init_latent_dist = self.vae.encode(image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    init_latents = self.vae.config.scaling_factor * init_latents

    # Expand init_latents for batch_size and num_images_per_prompt
    init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
    init_latents_orig = init_latents

    # add noise to latents using the timesteps
    noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
    init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
    latents = init_latents
    return latents, init_latents_orig, noise


@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    add_predicted_noise: Optional[bool] = False,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        image (`torch.FloatTensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process. This is the image whose masked region will be inpainted.
        mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
            PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
            expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
            is 1, the denoising process will be run on the masked area for the full number of iterations specified
            in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to
            that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
            the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        add_predicted_noise (`bool`, *optional*, defaults to False):
            Use predicted noise instead of random noise when constructing noisy versions of the original image in
            the reverse diffusion process. With guidance enabled the unconditional prediction is used; with
            guidance disabled the model's single prediction is used.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """
    # 1. Check inputs
    self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess image and mask
    if not isinstance(image, torch.FloatTensor):
        image = preprocess_image(image, batch_size)

    mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 6. Prepare latent variables
    # encode the init image into latents and scale the latents
    latents, init_latents_orig, noise = self.prepare_latents(
        image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )

    # 7. Prepare mask latent
    mask = mask_image.to(device=device, dtype=latents.dtype)
    mask = torch.cat([mask] * num_images_per_prompt)

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                # Without guidance the batch was not doubled, so there is no separate unconditional
                # chunk. Bind the single prediction so `add_predicted_noise` below does not hit an
                # unbound name.
                noise_pred_uncond = noise_pred

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # masking: re-noise the original latents to the current timestep and blend them back in
            if add_predicted_noise:
                init_latents_proper = self.scheduler.add_noise(
                    init_latents_orig, noise_pred_uncond, torch.tensor([t])
                )
            else:
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))

            latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # use original latents corresponding to unmasked portions of the image
    latents = (init_latents_orig * mask) + (latents * (1 - mask))

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are not denormalized
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def preprocess(image):
    """
    Convert an image input to a `torch.Tensor` batch (deprecated; emits a `FutureWarning`).

    A tensor is returned untouched. A PIL image, or a sequence whose first element is a PIL image, is resized
    to a multiple of 8 in each dimension (size taken from the first image), scaled to `[-1, 1]` and moved to
    channels-first layout. A sequence whose first element is a tensor is concatenated along dim 0. Any other
    input is returned as given.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image

    images = [image] if isinstance(image, PIL.Image.Image) else image
    first = images[0]

    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # resize to integer multiple of 8
        width -= width % 8
        height -= height % 8

        frames = [
            np.array(img.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for img in images
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(first, torch.Tensor):
        return torch.cat(images, dim=0)

    return images
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the VAE scale factor and image processor.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler: components registered on the pipeline as given.
        safety_checker: may be `None` to disable the checker; a warning is logged when it is `None` while
            `requires_safety_checker` is true.
        feature_extractor: must not be `None` when `safety_checker` is set.
        requires_safety_checker (`bool`, defaults to `True`): stored in the pipeline config.

    Raises:
        ValueError: if a `safety_checker` is passed without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix was missing here, so the message showed the literal text `{self.__class__}`.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # one spatial halving per VAE block after the first
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. This pipeline requires a value of at least `1`. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to generate images that are closely linked to the source image `image`, usually at the expense of lower image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
Examples: ```py >>> import PIL >>> import requests >>> import torch >>> from io import BytesIO >>> from diffusers import StableDiffusionInstructPix2PixPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" >>> image = download_image(img_url).resize((512, 512)) >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "make the mountains snowy" >>> image = pipe(prompt=prompt, image=image).images[0] ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Check inputs self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") # 1. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 # check if scheduler is in sigmas space scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") # 2. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 3. Preprocess image image = self.image_processor.preprocess(image) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare Image latents image_latents = self.prepare_image_latents( image, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, do_classifier_free_guidance, generator, ) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual noise_pred = self.unet( scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the # predicted_original_sample instead of the noise_pred. So we need to compute the # predicted_original_sample here if we are using a karras style scheduler. if scheduler_is_in_sigma_space: step_index = (self.scheduler.timesteps == t).nonzero()[0].item() sigma = self.scheduler.sigmas[step_index] noise_pred = latent_model_input - sigma * noise_pred # perform guidance if do_classifier_free_guidance: noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) noise_pred = ( noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_image) + image_guidance_scale * (noise_pred_image - noise_pred_uncond) ) # Hack: # For karras style schedulers the model does classifer free guidance using the # predicted_original_sample instead of the noise_pred. But the scheduler.step function # expects the noise_pred and computes the predicted_original_sample internally. So we # need to overwrite the noise_pred here such that the value of the computed # predicted_original_sample is correct. 
if scheduler_is_in_sigma_space: noise_pred = (noise_pred - latents) / (-sigma) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise 
TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def check_inputs( self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def prepare_image_latents( self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: image_latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if isinstance(generator, list): image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] image_latents = torch.cat(image_latents, dim=0) else: image_latents = self.vae.encode(image).latent_dist.mode() if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand image_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // image_latents.shape[0] image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." ) else: image_latents = torch.cat([image_latents], dim=0) if do_classifier_free_guidance: uncond_image_latents = torch.zeros_like(image_latents) image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) return image_latents ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import importlib import warnings from typing import Callable, List, Optional, Union import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import get_sigmas_karras from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class ModelWrapper: def __init__(self, model, alphas_cumprod): self.model = model self.alphas_cumprod = alphas_cumprod def apply_model(self, *args, **kwargs): if len(args) == 3: encoder_hidden_states = args[-1] args = args[:2] if kwargs.get("cond", None) is not None: encoder_hidden_states = kwargs.pop("cond") return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) This is an experimental pipeline and is likely to change in the future. Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker: bool = True, ): super().__init__() logger.info( f"{self.__class__} is an experimntal pipeline and is likely to change in the future. We recommend to use" " this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines" " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for" " production settings." 
) # get correct sigmas from LMS scheduler = LMSDiscreteScheduler.from_config(scheduler.config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) model = ModelWrapper(unet, scheduler.alphas_cumprod) if scheduler.config.prediction_type == "v_prediction": self.k_diffusion_model = CompVisVDenoiser(model) else: self.k_diffusion_model = CompVisDenoiser(model) def set_scheduler(self, scheduler_type: str): library = importlib.import_module("k_diffusion") sampling = getattr(library, "sampling") self.sampler = getattr(sampling, scheduler_type) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over a batch of images.

    Args:
        image: Batch of images, either a torch tensor or an array accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted `pixel_values` are cast to before they are
            handed to the safety checker.

    Returns:
        A `(image, has_nsfw_concept)` tuple. When `self.safety_checker` is
        `None` the input is returned untouched and the second element is
        `None`; otherwise both values are whatever the safety checker returns.
    """
    if self.safety_checker is None:
        # No checker configured: nothing to screen.
        return image, None

    # The feature extractor is fed PIL images, so convert whichever
    # representation we were given.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the user-facing arguments of the pipeline call.

    Checks, in order: image size divisibility, `callback_steps`, the
    `prompt` / `prompt_embeds` pair, the `negative_prompt` /
    `negative_prompt_embeds` pair, and finally that both embedding tensors
    (when both are given) share one shape.

    Raises:
        ValueError: On the first violated constraint.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an `int`, so this single test also rejects a missing value.
    steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Pre-computed positive and negative embeddings are concatenated later,
    # so they have to agree in shape.
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). This pipeline always applies classifier-free guidance, so `guidance_scale` must be greater than `1`; a value less than or equal to `1` raises a `ValueError`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. 
prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. Note that the body of this `__call__` hands the whole sampling loop to `self.sampler` and does not itself invoke `callback`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M Karras`. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = True if guidance_scale <= 1.0: raise ValueError("has to use guidance_scale") # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) # 5. Prepare sigmas if use_karras_sigmas: sigma_min: float = self.k_diffusion_model.sigmas[0].item() sigma_max: float = self.k_diffusion_model.sigmas[-1].item() sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) sigmas = sigmas.to(device) else: sigmas = self.scheduler.sigmas sigmas = sigmas.to(prompt_embeds.dtype) # 6. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) latents = latents * sigmas[0] self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) # 7. Define model function def model_fn(x, t): latent_model_input = torch.cat([x] * 2) t = torch.cat([t] * 2) noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) return noise_pred # 8. Run k-diffusion solver latents = self.sampler(model_fn, latents, sigmas) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
def preprocess(image):
    """Convert an image input into a torch tensor batch (deprecated helper).

    Always emits a `FutureWarning` pointing at `VaeImageProcessor.preprocess`.

    Behaviour by input type:
      * `torch.Tensor`: returned unchanged.
      * single `PIL.Image.Image`: treated as a one-element list.
      * list whose first item is a PIL image: every image is resized to the
        first image's size rounded down to a multiple of 64, the results are
        stacked on a new leading axis, converted to float32, mapped through
        `2 * (x / 255) - 1`, and moved from axis order (0, 1, 2, 3) to
        (0, 3, 1, 2) before being wrapped in a tensor.
      * list whose first item is a tensor: concatenated along dim 0.
      * anything else: returned as received.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # Snap both sides down to an integer multiple of 64.
        width = width - width % 64
        height = height - height % 64

        frames = [np.array(picture.resize((width, height)))[None, :] for picture in image]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(first, torch.Tensor):
        return torch.cat(image, dim=0)

    return image
""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: EulerDiscreteScheduler, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `list(int)`): prompt to be encoded device: (`torch.device`): torch device do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). """ batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_length=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_encoder_out = self.text_encoder( text_input_ids.to(device), output_hidden_states=True, ) text_embeddings = text_encoder_out.hidden_states[-1] text_pooler_out = text_encoder_out.pooler_output # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_length=True, return_tensors="pt", ) uncond_encoder_out = self.text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) uncond_embeddings = uncond_encoder_out.hidden_states[-1] uncond_pooler_out = uncond_encoder_out.pooler_output # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out]) return text_embeddings, text_pooler_out # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
def check_inputs(self, prompt, image, callback_steps):
    """Validate `prompt`, `image` and `callback_steps` before upscaling.

    Args:
        prompt: Must be a `str` or a `list`.
        image: Must be a `torch.Tensor`, a `PIL.Image.Image` or a `list`.
            For lists and tensors the batch size has to equal the prompt
            batch size; a tensor counts as batch size `shape[0]` only when it
            is 4-dimensional, otherwise as `1`.
        callback_steps: Must be a positive `int`.

    Raises:
        ValueError: On the first violated constraint.
    """
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    image_type_ok = (
        isinstance(image, torch.Tensor) or isinstance(image, PIL.Image.Image) or isinstance(image, list)
    )
    if not image_type_ok:
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
        )

    # A single PIL image carries no batch dimension, so the batch-size
    # comparison only applies to lists and tensors.
    if isinstance(image, (list, torch.Tensor)):
        prompt_count = 1 if isinstance(prompt, str) else len(prompt)

        if isinstance(image, list):
            image_count = len(image)
        elif image.ndim == 4:
            image_count = image.shape[0]
        else:
            image_count = 1

        if prompt_count != image_count:
            raise ValueError(
                f"`prompt` has batch size {prompt_count} and `image` has batch size {image_count}."
                " Please make sure that passed `prompt` matches the batch size of `image`."
            )

    # `None` is not an `int`, so this single test also rejects a missing value.
    if not (isinstance(callback_steps, int) and callback_steps > 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
num_inference_steps (`int`, *optional*, defaults to 75): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 9.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. Note: the `__call__` signature of this pipeline has no `eta` parameter. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline >>> import torch >>> pipeline = StableDiffusionPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... ) >>> pipeline.to("cuda") >>> model_id = "stabilityai/sd-x2-latent-upscaler" >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16) >>> upscaler.to("cuda") >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic" >>> generator = torch.manual_seed(33) >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images >>> with torch.no_grad(): ... image = pipeline.decode_latents(low_res_latents) >>> image = pipeline.numpy_to_pil(image)[0] >>> image.save("../images/a1.png") >>> upscaled_image = upscaler( ... prompt=prompt, ... image=low_res_latents, ... num_inference_steps=20, ... guidance_scale=0, ... generator=generator, ... ).images[0] >>> upscaled_image.save("../images/a2.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs(prompt, image, callback_steps) # 2. 
Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if guidance_scale == 0: prompt = [""] * batch_size # 3. Encode input prompt text_embeddings, text_pooler_out = self._encode_prompt( prompt, device, do_classifier_free_guidance, negative_prompt ) # 4. Preprocess image image = self.image_processor.preprocess(image) image = image.to(dtype=text_embeddings.dtype, device=device) if image.shape[1] == 3: # encode image if not in latent-space yet image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps batch_multiplier = 2 if do_classifier_free_guidance else 1 image = image[None, :] if image.ndim == 3 else image image = torch.cat([image] * batch_multiplier) # 5. Add noise to image (set to be 0): # (see below notes from the author): # "the This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default." 
noise_level = torch.tensor([0.0], dtype=torch.float32, device=device) noise_level = torch.cat([noise_level] * image.shape[0]) inv_noise_level = (noise_level**2 + 1) ** (-0.5) image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None] image_cond = image_cond.to(text_embeddings.dtype) noise_level_embed = torch.cat( [ torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device), ], dim=1, ) timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1) # 6. Prepare latent variables height, width = image.shape[2:] num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size, num_channels_latents, height * 2, # 2x upscale width * 2, text_embeddings.dtype, device, generator, latents, ) # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 9. Denoising loop num_warmup_steps = 0 with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): sigma = self.scheduler.sigmas[i] # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1) # preconditioning parameter based on Karras et al. 
(2022) (table 1) timestep = torch.log(sigma) * 0.25 noise_pred = self.unet( scaled_model_input, timestep, encoder_hidden_states=text_embeddings, timestep_cond=timestep_condition, ).sample # in original repo, the output contains a variance channel that's not used noise_pred = noise_pred[:, :-1] # apply preconditioning, based on table 1 in Karras et al. (2022) inv_sigma = 1 / (sigma**2 + 1) noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py ================================================ # Copyright 2023 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessorLDM3D from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BaseOutput, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionPipeline >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d") >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> output = pipe(prompt) >>> rgb_image, depth_image = output.rgb, output.depth ``` """ @dataclass class LDM3DPipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, or `None` if safety checking could not be performed. """ rgb: Union[List[PIL.Image.Image], np.ndarray] depth: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] class StableDiffusionLDM3DPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image and 3d generation using LDM3D. LDM3D: Latent Diffusion Model for 3D: https://arxiv.org/abs/2305.10853 This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode rgb and depth images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded rgb and depth latents. 
scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. 
""" self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) rgb_feature_extractor_input = feature_extractor_input[0] safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 49, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for 
generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return ((rgb, 
depth), has_nsfw_concept) return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py ================================================ # Copyright 2023 TIME Authors and The HuggingFace Team. All rights reserved." # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . 
class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models".

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        with_to_k ([`bool`]):
            Whether to edit the key projection matrices along with the value projection matrices.
        with_augs ([`list`]):
            Textual augmentations to apply while editing the text-to-image model. Set to [] for no augmentations.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        with_to_k: bool = True,
        with_augs: list = AUGS_CONST,
    ):
        super().__init__()

        if isinstance(scheduler, PNDMScheduler):
            logger.error("PNDMScheduler for this pipeline is currently not supported.")

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # NOTE: this message needs the `f` prefix, otherwise `{self.__class__}` is printed literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

        self.with_to_k = with_to_k
        # Copy so that mutating `self.with_augs` never alters the shared module-level default list.
        self.with_augs = list(with_augs)

        # get cross-attention layers
        ca_layers = []

        def append_ca(net_):
            # Layers are recognised by class name only.
            if net_.__class__.__name__ == "CrossAttention":
                ca_layers.append(net_)
            elif hasattr(net_, "children"):
                for net__ in net_.children():
                    append_ca(net__)

        # recursively find all cross-attention layers in unet
        for child_name, child in self.unet.named_children():
            if "down" in child_name or "up" in child_name or "mid" in child_name:
                append_ca(child)

        # get projection matrices
        # NOTE(review): 768 is hard-coded; presumably the text-embedding width of the CLIP encoder used by
        # Stable Diffusion v1 -- confirm before using this pipeline with a different text encoder.
        self.ca_clip_layers = [layer for layer in ca_layers if layer.to_v.in_features == 768]
        self.projection_matrices = [layer.to_v for layer in self.ca_clip_layers]
        # pristine copies, used by `edit_model(restart_params=True)` to undo earlier edits
        self.og_matrices = [copy.deepcopy(layer.to_v) for layer in self.ca_clip_layers]
        if self.with_to_k:
            # key projections are appended after all value projections (index = num_layers + layer index)
            self.projection_matrices = self.projection_matrices + [layer.to_k for layer in self.ca_clip_layers]
            self.og_matrices = self.og_matrices + [copy.deepcopy(layer.to_k) for layer in self.ca_clip_layers]

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Raises:
            ImportError: if `accelerate` is missing or older than v0.14.0.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            The prompt embeddings, repeated `num_images_per_prompt` times. With classifier free guidance the
            unconditional embeddings are concatenated in front of them along the batch dimension.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # tokenize again without truncation, only to report what was cut off
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            # pad the unconditional input to the same sequence length as the prompt embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        """
        Run the optional safety checker on `image`.

        Returns the pair `(image, has_nsfw_concept)`. When no safety checker is registered the image is returned
        unchanged and `has_nsfw_concept` is `None`.
        """
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # the feature extractor is fed PIL images, whatever format `image` arrives in
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        """
        Decode `latents` with the VAE into a float32 NumPy image batch (channels last, values clamped to [0, 1]).

        Deprecated: emits a `FutureWarning`; use `VaeImageProcessor` instead.
        """
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        """
        Build the keyword arguments for `self.scheduler.step`, keeping only the ones its signature accepts.

        Returns a dict that may contain `eta` and/or `generator`.
        """
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())

        extra_step_kwargs = {}
        if "eta" in step_parameters:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        if "generator" in step_parameters:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the arguments of `__call__`.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer,
                `prompt`/`prompt_embeds` (or their negative counterparts) are both given, neither `prompt` nor
                `prompt_embeds` is given, `prompt` has an unsupported type, or the two embedding tensors differ in
                shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

        Fresh noise is sampled when `latents` is `None`; otherwise the given latents are moved to `device`.

        Raises:
            ValueError: if `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def edit_model(
        self,
        source_prompt: str,
        destination_prompt: str,
        lamb: float = 0.1,
        restart_params: bool = True,
    ):
        r"""
        Apply model editing via closed-form solution (see Eq. 5 in the TIME paper https://arxiv.org/abs/2303.08084)

        Args:
            source_prompt (`str`):
                The source prompt containing the concept to be edited.
            destination_prompt (`str`):
                The destination prompt. Must contain all words from source_prompt with additional ones to specify the
                target edit.
            lamb (`float`, *optional*, defaults to 0.1):
                The lambda parameter specifying the regularization intensity. Smaller values increase the editing
                power.
            restart_params (`bool`, *optional*, defaults to True):
                Restart the model parameters to their pre-trained version before editing. This is done to avoid edit
                compounding. When it is False, edits accumulate.

        Raises:
            IndexError: if a token of `source_prompt` cannot be found, in order, in `destination_prompt`.
        """
        # restart LDM parameters
        if restart_params:
            num_ca_clip_layers = len(self.ca_clip_layers)
            for idx_, layer in enumerate(self.ca_clip_layers):
                layer.to_v = copy.deepcopy(self.og_matrices[idx_])
                self.projection_matrices[idx_] = layer.to_v
                if self.with_to_k:
                    # key projections are stored after all value projections
                    layer.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_])
                    self.projection_matrices[num_ca_clip_layers + idx_] = layer.to_k

        def with_augmentations(text):
            # The prompt itself, followed by one variant per augmentation prefix. A leading capital "A" is
            # lower-cased before a prefix is prepended ("A pack ..." -> "A photo of a pack ...").
            base = text if text[0:1] != "A" else "a" + text[1:]
            return [text] + [aug + base for aug in self.with_augs]

        # set up sentences (with augmentations)
        old_texts = with_augmentations(source_prompt)
        new_texts = with_augmentations(destination_prompt)

        # Embeddings are padded to this length, so the token alignment below must use the same value
        # instead of a hard-coded 77.
        max_length = self.tokenizer.model_max_length

        # prepare input k* and v*
        old_embs, new_embs = [], []
        for old_text, new_text in zip(old_texts, new_texts):
            text_input = self.tokenizer(
                [old_text, new_text],
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
            old_emb, new_emb = text_embeddings
            old_embs.append(old_emb)
            new_embs.append(new_emb)

        # Token id that replaces "an", so that "a"/"an" never break the alignment (computed once).
        article_id = self.tokenizer.encode("a ")[1]

        def normalize_articles(token_ids):
            return [article_id if self.tokenizer.decode(t) == "an" else t for t in token_ids]

        # identify corresponding destinations for each token in old_emb
        idxs_replaces = []
        for old_text, new_text in zip(old_texts, new_texts):
            tokens_a = normalize_articles(self.tokenizer(old_text).input_ids)
            tokens_b = normalize_articles(self.tokenizer(new_text).input_ids)
            idxs_replace = []
            j = 0
            for curr_token in tokens_a:
                # advance in the destination until the current source token is found
                while j < len(tokens_b) and tokens_b[j] != curr_token:
                    j += 1
                if j == len(tokens_b):
                    raise IndexError(
                        "`destination_prompt` must contain every token of `source_prompt` in the same order, but"
                        f" a token of {old_text!r} could not be aligned with {new_text!r}."
                    )
                idxs_replace.append(j)
                j += 1
            # remaining positions map one-to-one onto the rest of the destination sequence ...
            idxs_replace.extend(range(j, max_length))
            # ... and whatever is still missing points at the last position
            idxs_replace.extend([max_length - 1] * (max_length - len(idxs_replace)))
            idxs_replaces.append(idxs_replace)

        # prepare batch: for each pair of sentences, old context and new values
        contexts, valuess = [], []
        for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
            contexts.append(old_emb.detach())
            valuess.append([layer(new_emb[idxs_replace]).detach() for layer in self.projection_matrices])

        # edit the model
        for layer_num, projection in enumerate(self.projection_matrices):
            # mat1 = \lambda W + \sum{v k^T}
            mat1 = lamb * projection.weight

            # mat2 = \lambda I + \sum{k k^T}
            mat2 = lamb * torch.eye(
                projection.weight.shape[1],
                device=projection.weight.device,
            )

            # aggregate sums for mat1, mat2
            for context, values in zip(contexts, valuess):
                value = values[layer_num]
                context_vector = context.reshape(context.shape[0], context.shape[1], 1)
                context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
                value_vector = value.reshape(value.shape[0], value.shape[1], 1)
                # batched outer products over token positions, summed over that leading dimension
                for_mat1 = (value_vector @ context_vector_T).sum(dim=0)
                for_mat2 = (context_vector @ context_vector_T).sum(dim=0)
                mat1 += for_mat1
                mat2 += for_mat2

            # update projection matrix
            projection.weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # flagged images are not denormalized; without a safety check every image is
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler >>> model_ckpt = "stabilityai/stable-diffusion-2-base" >>> scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler") >>> pipe = StableDiffusionPanoramaPipeline.from_pretrained( ... model_ckpt, scheduler=scheduler, torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "a photo of the dolomites" >>> image = pipe(prompt).images[0] ``` """ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation". This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). To generate panorama-like images, be sure to pass the `width` parameter accordingly when using the pipeline. Our recommendation for the `width` value is 2048. This is the default value of the `width` parameter for this pipeline. Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. The original work on Multi Diffsion used the [`DDIMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): # Here, we define the mappings F_i (see Eq. 
7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) # if panorama's height/width < window_size, num_blocks of height/width should return 1 panorama_height /= 8 panorama_width /= 8 num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1 total_num_blocks = int(num_blocks_height * num_blocks_width) views = [] for i in range(total_num_blocks): h_start = int((i // num_blocks_width) * stride) h_end = h_start + window_size w_start = int((i % num_blocks_width) * stride) w_end = w_start + window_size views.append((h_start, h_end, w_start, w_end)) return views @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = 512, width: Optional[int] = 2048, num_inference_steps: int = 50, guidance_scale: float = 7.5, view_batch_size: int = 1, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to 512: The height in pixels of the generated image. width (`int`, *optional*, defaults to 2048): The width in pixels of the generated image. 
The width is kept to a high number because the pipeline is supposed to be used for generating panorama-like images. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. view_batch_size (`int`, *optional*, defaults to 1): The batch size to denoise splited views. For some GPUs with high performance, higher view batch size can speedup the generation and increase the VRAM usage. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Define panorama grid and initialize views for synthesis. 
# prepare batch grid views = self.get_views(height, width) views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) count = torch.zeros_like(latents) value = torch.zeros_like(latents) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop # Each denoising step also includes refinement of the latents with respect to the # views. num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): count.zero_() value.zero_() # generate views # Here, we iterate through different spatial crops of the latents and denoise them. These # denoised (latent) crops are then averaged to produce the final latent # for the current timestep via MultiDiffusion. Please see Sec. 
4.1 in the # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 # Batch views denoise for j, batch_view in enumerate(views_batch): vb_size = len(batch_view) # get the latents corresponding to the current view coordinates latents_for_view = torch.cat( [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view] ) # rematch block's scheduler status self.scheduler.__dict__.update(views_scheduler_status[j]) # expand the latents if we are doing classifier free guidance latent_model_input = ( latents_for_view.repeat_interleave(2, dim=0) if do_classifier_free_guidance else latents_for_view ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # repeat prompt_embeds for batch prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds_input, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents_denoised_batch = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample # save views scheduler status after sample views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) # extract value from batch for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( latents_denoised_batch.chunk(vb_size), batch_view ): value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 # take the MultiDiffusion step. Eq. 
5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = torch.where(count > 0, value / count, value) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py ================================================ # Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class StableDiffusionParadigmsPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317 This pipeline
    parallelizes the denoising steps to generate a single image faster (more akin to model parallelism).

    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. The parallel
            denoising loop calls `_get_variance`, `batch_step_no_noise` and `_is_ode_scheduler` on it, so the
            scheduler must provide them (the usage example pairs this pipeline with `DDPMParallelScheduler`).
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # NOTE: this must be an f-string, otherwise `{self.__class__}` is printed literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

        # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs
        self.wrapped_unet = self.unet

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
        """
        self.vae.enable_tiling()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the offloaded modules are executed on.

        Raises:
            ImportError: If `accelerate` is not available or is older than v0.14.0.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the offloaded models are executed on.

        Raises:
            ImportError: If `accelerate` is not available or is older than v0.17.0.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Chain the hooks so that running one model offloads the previously executed one.
        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
             prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

        Returns:
            `torch.FloatTensor`: the prompt embeddings, repeated `num_images_per_prompt` times. When
            `do_classifier_free_guidance` is set, the negative embeddings are concatenated in front of them along the
            batch dimension.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # Warn the user about the part of the prompt that did not fit into the tokenizer's window.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        r"""
        Run the optional safety checker on `image`.

        Returns:
            `tuple`: `(image, has_nsfw_concept)`. When no safety checker is registered the image is returned untouched
            and `has_nsfw_concept` is `None`; otherwise both values are whatever the safety checker returns.
        """
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # The feature extractor is fed PIL images, whatever the incoming representation is.
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        r"""
        Build the keyword arguments forwarded to the scheduler step, keeping only `eta` and `generator` when the
        signature of `self.scheduler.step` accepts them.
        """
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        r"""
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If `height`/`width` are not divisible by 8, `callback_steps` is not a positive integer,
                `prompt` and `prompt_embeds` are both given or both missing, `prompt` is neither `str` nor `list`,
                `negative_prompt` and `negative_prompt_embeds` are both given, or `prompt_embeds` and
                `negative_prompt_embeds` have different shapes.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        r"""
        Create (or move to `device`) the initial latents of shape `(batch_size, num_channels_latents,
        height // vae_scale_factor, width // vae_scale_factor)` and scale them by the scheduler's `init_noise_sigma`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _cumsum(self, input, dim, debug=False):
        r"""
        Cumulative sum of `input` along `dim`.

        With `debug=True` the sum is computed on the CPU in float32 and moved back to the device of `input`;
        otherwise `torch.cumsum` is applied to `input` directly.
        """
        if debug:
            # cumsum_cuda_kernel does not have a deterministic implementation
            # so perform cumsum on cpu for debugging purposes
            return torch.cumsum(input.cpu().float(), dim=dim).to(input.device)
        else:
            return torch.cumsum(input, dim=dim)

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        parallel: int = 10,
        tolerance: float = 0.1,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            parallel (`int`, *optional*, defaults to 10):
                The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but
                requires higher memory usage and also can require more total FLOPs.
            tolerance (`float`, *optional*, defaults to 0.1):
                The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower
                tolerance usually leads to less/no degradation. Higher tolerance is faster but can risk degradation of
                sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that is called after a window update whenever the number of completed window updates is a
                multiple of `callback_steps`. It is called as
                `callback(step: int, timestep: int, latents: torch.FloatTensor)`, where `step` is the index of the
                first timestep of the window that was just processed, `timestep` is the scheduler timestep at that
                index and `latents` is the buffered latent at that index.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            debug (`bool`, *optional*, defaults to `False`):
                Whether or not to run in debug mode. In debug mode, torch.cumsum is evaluated using the CPU.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # The noise is pre-sampled below with `generator`, and the scheduler is stepped through
        # `batch_step_no_noise`, so the generator is not forwarded to the scheduler.
        extra_step_kwargs.pop("generator", None)

        # 7. Denoising loop
        scheduler = self.scheduler
        parallel = min(parallel, len(scheduler.timesteps))

        begin_idx = 0
        end_idx = parallel
        # One latent per timestep boundary: entry i is the latent before timestep i, the last entry is the result.
        latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1))

        # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep.
        # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop.
        noise_array = torch.zeros_like(latents_time_evolution_buffer)
        for j in range(len(scheduler.timesteps)):
            base_noise = randn_tensor(
                shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype
            )
            noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise
            noise_array[j] = noise.clone()

        # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance
        # outside of the denoising loop to avoid recomputing it at every step.
        # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step.
        inverse_variance_norm = 1.0 / torch.tensor(
            [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0]
        ).to(noise_array.device)
        latent_dim = noise_array[0, 0].numel()
        inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim

        scaled_tolerance = tolerance**2

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            steps = 0
            while begin_idx < len(scheduler.timesteps):
                # these have shape (parallel_dim, 2*batch_size, ...)
                # parallel_len is at most parallel, but could be less if we are at the end of the timesteps
                # we are processing batch window of timesteps spanning [begin_idx, end_idx)
                parallel_len = end_idx - begin_idx

                block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len)
                block_latents = latents_time_evolution_buffer[begin_idx:end_idx]
                block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt)
                t_vec = block_t
                if do_classifier_free_guidance:
                    t_vec = t_vec.repeat(1, 2)

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents
                )
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec)

                # if parallel_len is small, no need to use multiple GPUs
                net = self.wrapped_unet if parallel_len > 3 else self.unet
                # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...]
                model_output = net(
                    latent_model_input.flatten(0, 1),
                    t_vec.flatten(0, 1),
                    encoder_hidden_states=block_prompt_embeds.flatten(0, 1),
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                per_latent_shape = model_output.shape[1:]
                if do_classifier_free_guidance:
                    model_output = model_output.reshape(
                        parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape
                    )
                    noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1]
                    model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                model_output = model_output.reshape(
                    parallel_len * batch_size * num_images_per_prompt, *per_latent_shape
                )

                block_latents_denoise = scheduler.batch_step_no_noise(
                    model_output=model_output,
                    timesteps=block_t.flatten(0, 1),
                    sample=block_latents.flatten(0, 1),
                    **extra_step_kwargs,
                ).reshape(block_latents.shape)

                # back to shape (parallel_dim, batch_size, ...)
                # now we want to add the pre-sampled noise
                # parallel sampling algorithm requires computing the cumulative drift from the beginning
                # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises.
                delta = block_latents_denoise - block_latents
                cumulative_delta = self._cumsum(delta, dim=0, debug=debug)
                cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug)

                # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise
                if scheduler._is_ode_scheduler:
                    cumulative_noise = 0

                block_latents_new = (
                    latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise
                )
                # Squared distance between the new estimate and the one stored by the previous iteration.
                cur_error = torch.linalg.norm(
                    (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape(
                        parallel_len, batch_size * num_images_per_prompt, -1
                    ),
                    dim=-1,
                ).pow(2)
                error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1]

                # find the first index of the vector error_ratio that is greater than error tolerance
                # we can shift the window for the next iteration up to this index
                error_ratio = torch.nn.functional.pad(
                    error_ratio, (0, 0, 0, 1), value=1e9
                )  # handle the case when everything is below ratio, by padding the end of parallel_len dimension
                any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int()
                ind = torch.argmax(any_error_at_time).item()

                # compute the new begin and end idxs for the window
                new_begin_idx = begin_idx + min(1 + ind, parallel)
                new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps))

                # store the computed latents for the current window in the global buffer
                latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new
                # initialize the new sliding window latents with the end of the current window,
                # should be better than random initialization
                latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][
                    None,
                ]

                steps += 1

                progress_bar.update(new_begin_idx - begin_idx)
                if callback is not None and steps % callback_steps == 0:
                    # `block_t` only covers the current window (its first dimension is `parallel_len`), so it must
                    # not be indexed with the global `begin_idx`; read the timestep from the full schedule instead.
                    callback(begin_idx, scheduler.timesteps[begin_idx], latents_time_evolution_buffer[begin_idx])

                begin_idx = new_begin_idx
                end_idx = new_end_idx

        latents = latents_time_evolution_buffer[-1]

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # Images flagged by the safety checker are not denormalized.
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import inspect import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from transformers import ( BlipForConditionalGeneration, BlipProcessor, CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, ) from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ...utils import ( PIL_INTERPOLATION, BaseOutput, deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): """ Output class for Stable Diffusion pipelines. Args: latents (`torch.FloatTensor`) inverted latents tensor images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. """ latents: torch.FloatTensor images: Union[List[PIL.Image.Image], np.ndarray] EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline >>> def download(embedding_url, local_filepath): ... r = requests.get(embedding_url) ... with open(local_filepath, "wb") as f: ... 
f.write(r.content) >>> model_ckpt = "CompVis/stable-diffusion-v1-4" >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.to("cuda") >>> prompt = "a high resolution painting of a cat in the style of van gough" >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" >>> for url in [source_emb_url, target_emb_url]: ... download(url, url.split("/")[-1]) >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) >>> images = pipeline( ... prompt, ... source_embeds=src_embeds, ... target_embeds=target_embeds, ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... ).images >>> images[0].save("edited_image_dog.png") ``` """ EXAMPLE_INVERT_DOC_STRING = """ Examples: ```py >>> import torch >>> from transformers import BlipForConditionalGeneration, BlipProcessor >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline >>> import requests >>> from PIL import Image >>> captioner_id = "Salesforce/blip-image-captioning-base" >>> processor = BlipProcessor.from_pretrained(captioner_id) >>> model = BlipForConditionalGeneration.from_pretrained( ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True ... ) >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( ... sd_model_ckpt, ... caption_generator=model, ... caption_processor=processor, ... torch_dtype=torch.float16, ... safety_checker=None, ... 
) >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) >>> pipeline.enable_model_cpu_offload() >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) >>> # generate caption >>> caption = pipeline.generate_caption(raw_image) >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" >>> inv_latents = pipeline.invert(caption, image=raw_image).latents >>> # we need to generate source and target embeds >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] >>> source_embeds = pipeline.get_embeds(source_prompts) >>> target_embeds = pipeline.get_embeds(target_prompts) >>> # the latents can then be used to edit a real image >>> # when using Stable Diffusion 2 or other models that use v-prediction >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion >>> image = pipeline( ... caption, ... source_embeds=source_embeds, ... target_embeds=target_embeds, ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... generator=generator, ... latents=inv_latents, ... negative_prompt=caption, ... ).images[0] >>> image.save("edited_image.png") ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): warnings.warn( "The preprocess method is deprecated and will be removed in a future version. 
class Pix2PixZeroAttnProcessor:
    """Attention processor that can record attention maps and later compare against them.

    When `is_pix2pix_zero` is True the processor owns a dictionary of reference attention
    probabilities keyed by the timestep value. A call that receives a `timestep` but no `loss`
    stores the current probabilities (detached, on CPU). A call that receives both removes the
    stored entry for that timestep and passes it, together with the current probabilities, to
    `loss.compute_loss`.
    """

    def __init__(self, is_pix2pix_zero=False):
        self.is_pix2pix_zero = is_pix2pix_zero
        # Only the recording variant carries a reference map.
        if is_pix2pix_zero:
            self.reference_cross_attn_map = {}

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        timestep=None,
        loss=None,
    ):
        """Run attention through the projections of `attn`, optionally recording or comparing maps.

        Args:
            attn: attention module supplying the projections and helper methods used below.
            hidden_states: 3-D input; its first two dimensions are read as batch size and
                sequence length.
            encoder_hidden_states: context for keys/values; `hidden_states` is used when None.
            attention_mask: forwarded to `attn.prepare_attention_mask`.
            timestep: object exposing `.item()`; enables the bookkeeping when not None.
            loss: object exposing `compute_loss(current, reference)`; selects compare mode.

        Returns:
            The output of `attn.to_out[1]` applied after `attn.to_out[0]`.
        """
        n_batch, n_tokens, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)
        query = attn.to_q(hidden_states)

        # Without an explicit context the layer attends over its own input.
        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)
        query, key, value = (attn.head_to_batch_dim(tensor) for tensor in (query, key, value))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        if self.is_pix2pix_zero and timestep is not None:
            step_key = timestep.item()
            if loss is None:
                # Recording pass: keep a detached CPU copy for the later comparison.
                self.reference_cross_attn_map[step_key] = attention_probs.detach().cpu()
            else:
                # Comparison pass: the stored map is consumed (popped) exactly once.
                reference = self.reference_cross_attn_map.pop(step_key)
                loss.compute_loss(attention_probs, reference.to(attention_probs.device))

        attended = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        return attn.to_out[1](projected)
Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. requires_safety_checker (bool): Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the pipeline publicly. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],
    feature_extractor: CLIPImageProcessor,
    safety_checker: StableDiffusionSafetyChecker,
    inverse_scheduler: DDIMInverseScheduler,
    caption_generator: BlipForConditionalGeneration,
    caption_processor: BlipProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the VAE scale factor.

    Logs a warning when `safety_checker` is None while `requires_safety_checker` is True.

    Raises:
        ValueError: if a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message is an f-string so that the class name is actually interpolated
        # (previously the literal text "{self.__class__}" was shown to the user).
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        caption_processor=caption_processor,
        caption_generator=caption_generator,
        inverse_scheduler=inverse_scheduler,
    )
    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Device on which the pipeline's models are executed.

    If the UNet carries no `_hf_hook` attribute this is simply `self.device`. Otherwise the UNet's
    submodules are scanned in order and the first hook that reports a non-None `execution_device`
    decides the result; `self.device` is the fallback when no such hook is found.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            # A missing hook, or a hook without the attribute, both resolve to None here.
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype):
    """Pass `image` through the safety checker when one is configured.

    Args:
        image: tensor or non-tensor image batch; tensors go through
            `image_processor.postprocess(..., output_type="pil")`, anything else through
            `image_processor.numpy_to_pil`.
        device: device the feature-extractor output is moved to.
        dtype: dtype the extracted `pixel_values` are cast to.

    Returns:
        `(image, has_nsfw_concept)`; the second element is None when no checker is set,
        in which case `image` is returned untouched.
    """
    if self.safety_checker is None:
        return image, None

    pil_images = (
        self.image_processor.postprocess(image, output_type="pil")
        if torch.is_tensor(image)
        else self.image_processor.numpy_to_pil(image)
    )
    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    checked_image, nsfw_flags = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return checked_image, nsfw_flags
def check_inputs(
    self,
    prompt,
    source_embeds,
    target_embeds,
    callback_steps,
    prompt_embeds=None,
):
    """Validate the arguments of a pipeline call before any work is done.

    Args:
        prompt: a `str`, a `list`, or None when `prompt_embeds` is given instead.
        source_embeds: source concept embeddings; required.
        target_embeds: target concept embeddings; required.
        callback_steps: positive integer.
        prompt_embeds: pre-computed prompt embeddings, mutually exclusive with `prompt`.

    Raises:
        ValueError: if `callback_steps` is not a positive integer, if either embedding is
            missing, if both or neither of `prompt` / `prompt_embeds` are given, or if
            `prompt` is neither a `str` nor a `list`.
    """
    # `None` is not an int, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Both embeddings feed the edit direction (each has `.mean(0)` taken in
    # `construct_direction`), so a single missing one must be rejected here rather than
    # surfacing later as an AttributeError on None.
    if source_embeds is None or target_embeds is None:
        raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
    """Turn `image` into scaled VAE latents with a leading dimension of `batch_size`.

    Args:
        image: a `torch.Tensor`, a `PIL.Image.Image` or a list. Non-tensor inputs are first
            converted with `self.image_processor.preprocess` (NOTE(review): assumes that
            call returns a batched tensor for PIL images and lists - confirm against
            `VaeImageProcessor`).
        batch_size: required size of the first dimension of the result.
        dtype: dtype the image tensor is cast to.
        device: device the image tensor is moved to.
        generator: a generator, or a list with one generator per batch element, forwarded to
            `latent_dist.sample`.

    Returns:
        The latents. An input whose second dimension is 4 is treated as latents already and
        is not encoded; otherwise the VAE encoding is sampled and multiplied by
        `vae.config.scaling_factor`.

    Raises:
        ValueError: for an unsupported `image` type, a generator list whose length differs
            from `batch_size`, or latents that cannot be tiled evenly up to `batch_size`.
    """
    if not isinstance(image, torch.Tensor):
        if not (isinstance(image, list) or isinstance(image, PIL.Image.Image)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )
        # PIL images and lists have no `.to`; previously they passed the type check and
        # then failed on the cast below. Convert them to a tensor first.
        image = self.image_processor.preprocess(image)

    image = image.to(device=device, dtype=dtype)

    if image.shape[1] == 4:
        # Four channels: the caller passed latents directly.
        latents = image
    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            # One generator per sample: encode and sample each image separately.
            latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            latents = torch.cat(latents, dim=0)
        else:
            latents = self.vae.encode(image).latent_dist.sample(generator)

        latents = self.vae.config.scaling_factor * latents

    if batch_size != latents.shape[0]:
        if batch_size % latents.shape[0] == 0:
            # expand image_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_latents_per_image = batch_size // latents.shape[0]
            latents = torch.cat([latents] * additional_latents_per_image, dim=0)
        else:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
            )
    else:
        # Kept as in the original: returns a fresh tensor rather than the input object.
        latents = torch.cat([latents], dim=0)

    return latents
50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, cross_attention_guidance_amount: float = 0.1, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. source_embeds (`torch.Tensor`): Source concept embeddings. Generation of the embeddings as per the [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. target_embeds (`torch.Tensor`): Target concept embeddings. Generation of the embeddings as per the [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). 
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. cross_attention_guidance_amount (`float`, defaults to 0.1): Amount of guidance needed from the reference cross-attention maps. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Define the spatial resolutions. height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, source_embeds, target_embeds, callback_steps, prompt_embeds, ) # 3. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if cross_attention_kwargs is None: cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Generate the inverted noise from the input image or any other image # generated from the input prompt. num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) latents_init = latents.clone() # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Rejig the UNet so that we can obtain the cross-attenion maps and # use them for guiding the subsequent image generation. self.unet = prepare_unet(self.unet) # 7. Denoising loop where we obtain the cross-attention maps. 
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs={"timestep": t}, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Compute the edit directions. edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device) # 9. Edit the prompt embeddings as per the edit directions discovered. prompt_embeds_edit = prompt_embeds.clone() prompt_embeds_edit[1:2] += edit_direction # 10. Second denoising loop to generate the edited image. 
latents = latents_init num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # we want to learn the latent such that it steers the generation # process towards the edited direction, so make the make initial # noise learnable x_in = latent_model_input.detach().clone() x_in.requires_grad = True # optimizer opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount) with torch.enable_grad(): # initialize loss loss = Pix2PixZeroL2Loss() # predict the noise residual noise_pred = self.unet( x_in, t, encoder_hidden_states=prompt_embeds_edit.detach(), cross_attention_kwargs={"timestep": t, "loss": loss}, ).sample loss.loss.backward(retain_graph=False) opt.step() # recompute the noise noise_pred = self.unet( x_in.detach(), t, encoder_hidden_states=prompt_embeds_edit, cross_attention_kwargs={"timestep": None}, ).sample latents = x_in.detach().chunk(2)[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
def invert(
    self,
    prompt: Optional[str] = None,
    image: Union[
        torch.FloatTensor,
        PIL.Image.Image,
        np.ndarray,
        List[torch.FloatTensor],
        List[PIL.Image.Image],
        List[np.ndarray],
    ] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    cross_attention_guidance_amount: float = 0.1,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    lambda_auto_corr: float = 20.0,
    lambda_kl: float = 20.0,
    num_reg_steps: int = 5,
    num_auto_corr_rolls: int = 5,
):
    r"""
    Function used to generate inverted latents given a prompt and image.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
            image latents as `image`; if passing latents directly, it will not be encoded again.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 1):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Accepted for signature compatibility. The value passed in is not read: the latents are always
            computed from `image`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        cross_attention_guidance_amount (`float`, defaults to 0.1):
            Accepted for signature compatibility; not used by this method.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a
            [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`]
            instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            Accepted for signature compatibility; not forwarded to the UNet by this method.
        lambda_auto_corr (`float`, *optional*, defaults to 20.0):
            Weight of the auto-correlation regularization gradient. The term is skipped when not positive.
        lambda_kl (`float`, *optional*, defaults to 20.0):
            Weight of the Kullback–Leibler divergence regularization gradient. The term is skipped when not
            positive.
        num_reg_steps (`int`, *optional*, defaults to 5):
            Number of regularization steps applied to each noise prediction.
        num_auto_corr_rolls (`int`, *optional*, defaults to 5):
            Number of auto-correlation gradient updates per regularization step.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or
        `tuple`:
        [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the inverted
        latents tensor and the second is the corresponding decoded image.
    """
    # 1. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Preprocess image
    image = self.image_processor.preprocess(image)

    # 3. Prepare latent variables (this overwrites the `latents` argument)
    latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)

    # 4. Encode input prompt
    num_images_per_prompt = 1
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
    )

    # 5. Prepare timesteps
    self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.inverse_scheduler.timesteps

    # 6. Rejig the UNet so that we can obtain the cross-attention maps and
    # use them for guiding the subsequent image generation.
    self.unet = prepare_unet(self.unet)

    # 7. Inversion loop where we obtain the cross-attention maps.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
    with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
        for i, t in enumerate(timesteps[:-1]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs={"timestep": t},
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # regularization of the noise prediction; the method runs under `torch.no_grad()`,
            # so gradients have to be re-enabled locally
            with torch.enable_grad():
                for _ in range(num_reg_steps):
                    if lambda_auto_corr > 0:
                        for _ in range(num_auto_corr_rolls):
                            # fresh leaf tensor so that `.grad` is populated by `backward()`
                            var = noise_pred.detach().clone().requires_grad_(True)

                            # Derive epsilon from model output before regularizing to IID standard normal
                            var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                            l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
                            l_ac.backward()

                            grad = var.grad.detach() / num_auto_corr_rolls
                            noise_pred = noise_pred - lambda_auto_corr * grad

                    if lambda_kl > 0:
                        # fresh leaf tensor so that `.grad` is populated by `backward()`
                        var = noise_pred.detach().clone().requires_grad_(True)

                        # Derive epsilon from model output before regularizing to IID standard normal
                        var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                        l_kld = self.kl_divergence(var_epsilon)
                        l_kld.backward()

                        grad = var.grad.detach()
                        noise_pred = noise_pred - lambda_kl * grad

                    noise_pred = noise_pred.detach()

            # compute the next sample with the inverse scheduler
            latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    inverted_latents = latents.detach().clone()

    # 8. Post-processing
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (inverted_latents, image)

    return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionSAGPipeline >>> pipe = StableDiffusionSAGPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ... 
# attention processor that remembers the attention probabilities of its last call
class CrossAttnStoreProcessor:
    """Attention processor that stores the attention probabilities it computes.

    After each call, the probabilities returned by
    ``attn.get_attention_scores`` are available as ``self.attention_probs``
    (``None`` until the first call).
    """

    def __init__(self):
        self.attention_probs = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        """Run attention through the projections of ``attn`` and keep the probabilities.

        Args:
            attn: Attention module providing the projections and helpers used below.
            hidden_states: Input to the query projection; unpacked as a 3-D shape.
            encoder_hidden_states: Input to the key/value projections. Falls back to
                ``hidden_states`` when ``None``; otherwise it is normalized first if
                ``attn.norm_cross`` is truthy.
            attention_mask: Mask handed to ``attn.prepare_attention_mask``.

        Returns:
            The attention output after both entries of ``attn.to_out``.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        q = attn.to_q(hidden_states)

        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        k = attn.to_k(context)
        v = attn.to_v(context)

        q = attn.head_to_batch_dim(q)
        k = attn.head_to_batch_dim(k)
        v = attn.head_to_batch_dim(v)

        probs = attn.get_attention_scores(q, k, mask)
        self.attention_probs = probs

        out = attn.batch_to_head_dim(torch.bmm(probs, v))

        # to_out[0] is the linear projection, to_out[1] the dropout
        for layer_index in (0, 1):
            out = attn.to_out[layer_index](out)

        return out
Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
_optional_components = ["safety_checker", "feature_extractor"]


def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the VAE-dependent helpers.

    ``vae_scale_factor`` is ``2 ** (number of VAE block_out_channels - 1)``
    and is used to build the ``VaeImageProcessor``.
    ``requires_safety_checker`` is stored in the pipeline config.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "scheduler": scheduler,
        "safety_checker": safety_checker,
        "feature_extractor": feature_extractor,
    }
    self.register_modules(**components)

    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
    r"""
    Turn on sliced decoding by delegating to `self.vae.enable_slicing()`.
    """
    self.vae.enable_slicing()


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
    r"""
    Turn off sliced decoding by delegating to `self.vae.disable_slicing()`.
    """
    self.vae.disable_slicing()


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the unet, text encoder, vae and (if present) safety checker with accelerate's `cpu_offload`.

    The pipeline is first moved to CPU when it is not already there, and the CUDA cache is emptied. The
    execution device handed to `cpu_offload` is `cuda:{gpu_id}`.

    Raises:
        ImportError: If accelerate is not available in version 0.14.0 or higher.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for offloaded_model in (self.unet, self.text_encoder, self.vae):
        cpu_offload(offloaded_model, device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)


@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Return the device the pipeline's models execute on.

    When the unet carries no `_hf_hook`, this is `self.device`. Otherwise the first unet submodule whose
    `_hf_hook` has a non-`None` `execution_device` decides; `self.device` is the fallback.
    """
    if not hasattr(self.unet, "_hf_hook"):
        return self.device

    for module in self.unet.modules():
        if not hasattr(module, "_hf_hook"):
            continue
        execution_device = getattr(module._hf_hook, "execution_device", None)
        if execution_device is not None:
            return torch.device(execution_device)

    return self.device
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    r"""
    Validate the user-facing arguments of the pipeline call and raise on the first problem found.

    Checks, in order: `height`/`width` divisible by 8; `callback_steps` a positive `int`; exactly one of
    `prompt`/`prompt_embeds` supplied and `prompt` a `str` or `list`; `negative_prompt` and
    `negative_prompt_embeds` not both supplied; matching shapes when both embedding tensors are given.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an int, so a single isinstance test also rejects a missing value.
    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    r"""
    Produce the initial latents for the denoising loop.

    When `latents` is `None`, noise of shape
    `(batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)` is drawn with
    `randn_tensor`; otherwise the supplied tensor is moved to `device`. Either way the result is multiplied by
    `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // downscale, width // downscale)

    if isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is None:
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
    else:
        initial = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
None, callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. sag_scale (`float`, *optional*, defaults to 0.75): SAG scale as defined in [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance] (https://arxiv.org/abs/2210.00939). `sag_scale` is defined as `s_s` of equation (24) of SAG paper: https://arxiv.org/pdf/2210.00939.pdf. Typically chosen between [0, 1.0] for better quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # and `sag_scale` is` `s` of equation (16) # of the self-attentnion guidance paper: https://arxiv.org/pdf/2210.00939.pdf # `sag_scale = 0` means no self-attention guidance do_self_attention_guidance = sag_scale > 0.0 # 3. 
Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop store_processor = CrossAttnStoreProcessor() self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order map_size = None def get_map_size(module, input, output): nonlocal map_size map_size = output[0].shape[-2:] with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # perform self-attention guidance with the stored self-attentnion map if do_self_attention_guidance: # classifier-free guidance produces two 
chunks of attention map # and we only use unconditional one according to equation (25) # in https://arxiv.org/pdf/2210.00939.pdf if do_classifier_free_guidance: # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) # get the stored attention maps uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) # self-attention-based degrading of latents degraded_latents = self.sag_masking( pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t) ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred, t) # get the stored attention maps cond_attn = store_processor.attention_probs # self-attention-based degrading of latents degraded_latents = self.sag_masking( pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) ) # forward and give guidance degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample noise_pred += sag_scale * (noise_pred - degraded_pred) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = 
def sag_masking(self, original_latents, attn_map, map_size, t, eps):
    r"""
    Degrade the latents where self-attention is strong, then re-noise them.

    Same masking process as in the SAG paper: https://arxiv.org/pdf/2210.00939.pdf

    Args:
        original_latents: 4-D latent tensor to be degraded (unpacked as batch, channels, height, width).
        attn_map: 3-D attention tensor; it is reshaped to `(batch, heads, attn_map.shape[1], attn_map.shape[2])`.
        map_size: Two-element size the per-sample mask is reshaped to before being resized to the latent size.
        t: Timestep forwarded to `self.scheduler.add_noise`.
        eps: Noise forwarded to `self.scheduler.add_noise`.

    Returns:
        The blurred-where-masked latents after `self.scheduler.add_noise` has been applied.
    """
    _, query_len, key_len = attn_map.shape
    batch, channels, latent_height, latent_width = original_latents.shape

    # `attention_head_dim` may be configured per block; the last entry is the one used here.
    heads = self.unet.config.attention_head_dim
    heads = heads[-1] if isinstance(heads, list) else heads

    # Produce attention mask: average over heads, accumulate along the next axis and threshold at 1.0
    per_head = attn_map.reshape(batch, heads, query_len, key_len)
    salient = per_head.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
    mask = salient.reshape(batch, map_size[0], map_size[1])
    mask = mask.unsqueeze(1).repeat(1, channels, 1, 1).type(per_head.dtype)
    mask = F.interpolate(mask, (latent_height, latent_width))

    # Blur according to the self-attention mask
    blurred = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
    mixed = blurred * mask + original_latents * (1 - mask)

    # Noise it again to match the noise level
    return self.scheduler.add_noise(mixed, noise=eps, timesteps=t)
def pred_x0(self, sample, model_output, timestep):
    r"""
    Estimate the clean sample `x_0` from a noisy sample and the model output.

    The estimate is recomputed here rather than taken from the scheduler because some schedulers clip or do not
    return `x_0`. The formula depends on `self.scheduler.config.prediction_type`.

    Args:
        sample: Current noisy sample.
        model_output: Output of the denoising model for `sample`.
        timestep: Index into `self.scheduler.alphas_cumprod`.

    Returns:
        The predicted original sample. For `"sample"` prediction this is `model_output` itself.

    Raises:
        ValueError: If the prediction type is not `epsilon`, `sample` or `v_prediction`.
    """
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    prediction_type = self.scheduler.config.prediction_type

    if prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif prediction_type == "sample":
        pred_original_sample = model_output
    elif prediction_type == "v_prediction":
        # Only x_0 is needed here; the matching epsilon is computed by `pred_epsilon`, so no extra tensor is built.
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`,"
            " or `v_prediction`"
        )

    return pred_original_sample


def pred_epsilon(self, sample, model_output, timestep):
    r"""
    Estimate the noise `epsilon` contained in a noisy sample from the model output.

    Args:
        sample: Current noisy sample.
        model_output: Output of the denoising model for `sample`.
        timestep: Index into `self.scheduler.alphas_cumprod`.

    Returns:
        The predicted noise. For `"epsilon"` prediction this is `model_output` itself.

    Raises:
        ValueError: If the prediction type is not `epsilon`, `sample` or `v_prediction`.
    """
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    prediction_type = self.scheduler.config.prediction_type

    if prediction_type == "epsilon":
        pred_eps = model_output
    elif prediction_type == "sample":
        pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5)
    elif prediction_type == "v_prediction":
        pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`,"
            " or `v_prediction`"
        )

    return pred_eps


# Gaussian blur
def gaussian_blur_2d(img, kernel_size, sigma):
    r"""
    Blur the last two dimensions of `img` with a Gaussian kernel, one channel at a time.

    A normalized 1-D Gaussian of length `kernel_size` is turned into a 2-D kernel by an outer product and applied
    as a grouped convolution (one group per channel). The input is reflect-padded by `kernel_size // 2` on every
    side, so an odd `kernel_size` preserves the spatial size.

    Args:
        img: Tensor whose third-from-last dimension is the channel dimension.
        kernel_size: Number of taps of the 1-D kernel.
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        The blurred tensor, on the same device and with the same dtype as `img`.
    """
    ksize_half = (kernel_size - 1) * 0.5
    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    # Outer product of the 1-D kernel with itself, then one copy per channel for the grouped convolution.
    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
    img = F.pad(img, padding, mode="reflect")
    img = F.conv2d(img, kernel2d, groups=img.shape[-3])

    return img
Please" " use VaeImageProcessor.preprocess instead", FutureWarning, ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = [np.array(i.resize((w, h)))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image super-resolution using Stable Diffusion 2. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. low_res_scheduler ([`SchedulerMixin`]): A scheduler used to add initial noise to the low res conditioning image. It must be an instance of [`DDPMScheduler`]. 
scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ _optional_components = ["watermarker", "safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, low_res_scheduler: DDPMScheduler, scheduler: KarrasDiffusionSchedulers, safety_checker: Optional[Any] = None, feature_extractor: Optional[CLIPImageProcessor] = None, watermarker: Optional[Any] = None, max_noise_level: int = 350, ): super().__init__() if hasattr( vae, "config" ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate is_vae_scaling_factor_set_to_0_08333 = ( hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 ) if not is_vae_scaling_factor_set_to_0_08333: deprecation_message = ( "The configuration file of the vae does not contain `scaling_factor` or it is set to" f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" " incorrect results in future versions. 
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's models with accelerate's `cpu_offload` to reduce memory usage.

    `self.unet`, `self.text_encoder` and `self.vae` are each passed to `accelerate.cpu_offload` together with the
    device `cuda:{gpu_id}`; components that are `None` are skipped.

    Args:
        gpu_id (`int`, defaults to 0): Index of the CUDA device used as the execution device.

    Raises:
        ImportError: If accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    candidates = [self.unet, self.text_encoder, self.vae]
    for component in (model for model in candidates if model is not None):
        cpu_offload(component, target_device)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, nsfw_detected, watermark_detected = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype=dtype), ) else: nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() return image, nsfw_detected, watermark_detected # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = 
self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self,
    prompt,
    image,
    noise_level,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call, raising on the first problem found.

    Args:
        prompt: `str`, `list` or `None`. Exactly one of `prompt` and `prompt_embeds` must be given.
        image: a `torch.Tensor`, `PIL.Image.Image`, `np.ndarray` or `list`.
        noise_level: must not exceed `self.config.max_noise_level`.
        callback_steps: must be a positive `int`.
        negative_prompt: mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: pre-computed embeddings used in place of `prompt`.
        negative_prompt_embeds: pre-computed negative embeddings; when passed together with
            `prompt_embeds`, both must have the same shape.

    Raises:
        ValueError: if any of the constraints above is violated, or if the batch size implied by
            `prompt` / `prompt_embeds` differs from the batch size of a list/tensor/array `image`.
    """
    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if (
        not isinstance(image, torch.Tensor)
        and not isinstance(image, PIL.Image.Image)
        and not isinstance(image, np.ndarray)
        and not isinstance(image, list)
    ):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
        )

    # verify batch size of prompt and image are same if image is a list or tensor or numpy array
    if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
        # Same three-way rule `__call__` uses. The checks above guarantee that `prompt_embeds`
        # is set whenever `prompt` is None, so the fallback never dereferences None.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        image_batch_size = len(image) if isinstance(image, list) else image.shape[0]

        if batch_size != image_batch_size:
            raise ValueError(
                f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
                " Please make sure that passed `prompt` matches the batch size of `image`."
            )

    # check noise level
    if noise_level > self.config.max_noise_level:
        raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be upscaled. num_inference_steps (`int`, *optional*, defaults to 75): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 9.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. noise_level (`int`, *optional*, defaults to 20): The timestep used to add noise to the input `image` before upscaling; it is also passed to the `unet` as its class label. Must not exceed the pipeline config's `max_noise_level`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts.
If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: ```py >>> import requests >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableDiffusionUpscalePipeline >>> import torch >>> # load model and scheduler >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( ...
model_id, revision="fp16", torch_dtype=torch.float16 ... ) >>> pipeline = pipeline.to("cuda") >>> # let's download an image >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" >>> response = requests.get(url) >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") >>> low_res_img = low_res_img.resize((128, 128)) >>> prompt = "a white cat" >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] >>> upscaled_image.save("upsampled_cat.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs( prompt, image, noise_level, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image image = self.image_processor.preprocess(image) image = image.to(dtype=prompt_embeds.dtype, device=device) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 image = torch.cat([image] * batch_multiplier * num_images_per_prompt) noise_level = torch.cat([noise_level] * image.shape[0]) # 6. Prepare latent variables height, width = image.shape[2:] num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input, image], dim=1) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, class_labels=noise_level, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 10. 
Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype) else: latents = latents.float() # post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # 11. Apply watermark if output_type == "pil" and self.watermarker is not None: image = self.watermarker.apply_watermark(image) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableUnCLIPPipeline >>> pipe = StableUnCLIPPipeline.from_pretrained( ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 ... ) # TODO update model path >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> images = pipe(prompt).images >>> images[0].save("astronaut_horse.png") ``` """ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): """ Pipeline for text-to-image generation using stable unCLIP. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). Args: prior_tokenizer ([`CLIPTokenizer`]): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). prior_text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. prior ([`PriorTransformer`]): The canonical unCLIP prior to approximate the image embedding from the text embedding. prior_scheduler ([`KarrasDiffusionSchedulers`]): Scheduler used in the prior denoising process. image_normalizer ([`StableUnCLIPImageNormalizer`]): Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image embeddings after the noise has been applied. image_noising_scheduler ([`KarrasDiffusionSchedulers`]): Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). text_encoder ([`CLIPTextModel`]): Frozen text-encoder. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`KarrasDiffusionSchedulers`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
def __init__(
    self,
    # prior components
    prior_tokenizer: CLIPTokenizer,
    prior_text_encoder: CLIPTextModelWithProjection,
    prior: PriorTransformer,
    prior_scheduler: KarrasDiffusionSchedulers,
    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer,
    image_noising_scheduler: KarrasDiffusionSchedulers,
    # regular denoising components
    tokenizer: CLIPTokenizer,
    # annotated `CLIPTextModel` to agree with the class-level `text_encoder` declaration
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    # vae
    vae: AutoencoderKL,
):
    """Register all pipeline components and build the VAE image processor.

    The eleven components are handed to `register_modules` unchanged, under the same names
    as the parameters. Afterwards `vae_scale_factor` is derived from the number of entries
    in `vae.config.block_out_channels`, and a `VaeImageProcessor` using that factor is
    stored as `image_processor`.
    """
    super().__init__()

    self.register_modules(
        prior_tokenizer=prior_tokenizer,
        prior_text_encoder=prior_text_encoder,
        prior=prior,
        prior_scheduler=prior_scheduler,
        image_normalizer=image_normalizer,
        image_noising_scheduler=image_noising_scheduler,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        vae=vae,
    )

    # 2 ** (number of VAE blocks - 1)
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's models to CPU with accelerate, moving one whole model at a time to the GPU.

    Each of `text_encoder`, `prior_text_encoder`, `unet` and `vae` is wrapped with
    `cpu_offload_with_hook`; every hook is chained to the previous one. Compared to
    `enable_sequential_cpu_offload`, memory savings are lower but performance is much better because the
    `unet` is executed iteratively. The hook of the last model is stored in `final_offload_hook`.

    Raises:
        ImportError: if accelerate is missing or older than `0.17.0.dev0`.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_ok:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target_device = torch.device(f"cuda:{gpu_id}")

    already_on_cpu = self.device.type == "cpu"
    if not already_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        # otherwise we don't see the memory savings (but they probably exist)
        torch.cuda.empty_cache()

    offload_order = [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]
    previous_hook = None
    for stage in offload_order:
        _stage, previous_hook = cpu_offload_with_hook(stage, target_device, prev_module_hook=previous_hook)

    # The last model in the chain is offloaded manually through this hook.
    self.final_offload_hook = previous_hook
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
def _encode_prior_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
    text_attention_mask: Optional[torch.Tensor] = None,
):
    """Produce the prior's text conditioning: embeddings, hidden states and attention mask.

    When `text_model_output` is given, its items `[0]` and `[1]` are used as the embeddings and the
    hidden states, and `text_attention_mask` is used as the mask. Otherwise `prompt` is tokenized with
    `prior_tokenizer` and encoded with `prior_text_encoder`.

    All three results are repeated `num_images_per_prompt` times along dim 0. With
    `do_classifier_free_guidance`, the encoding of empty-string prompts is concatenated in front of
    each result.

    Returns:
        A tuple `(prompt_embeds, prior_text_encoder_hidden_states, text_mask)`.
    """
    if text_model_output is not None:
        # NOTE(review): `text_attention_mask` is used as-is; if it is None the repeat below fails.
        batch_size = text_model_output[0].shape[0]
        prompt_embeds = text_model_output[0]
        prior_text_encoder_hidden_states = text_model_output[1]
        text_mask = text_attention_mask
    else:
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        tokenized = self.prior_tokenizer(
            prompt,
            padding="max_length",
            max_length=self.prior_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = tokenized.input_ids
        text_mask = tokenized.attention_mask.bool().to(device)

        full_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        was_truncated = full_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, full_ids
        )
        if was_truncated:
            removed_text = self.prior_tokenizer.batch_decode(
                full_ids[:, self.prior_tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length]

        encoded = self.prior_text_encoder(text_input_ids.to(device))
        prompt_embeds = encoded.text_embeds
        prior_text_encoder_hidden_states = encoded.last_hidden_state

    # one copy of every conditioning row per requested image
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
        num_images_per_prompt, dim=0
    )
    text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

    if not do_classifier_free_guidance:
        return prompt_embeds, prior_text_encoder_hidden_states, text_mask

    # unconditional branch: encode one empty prompt per batch entry
    uncond_input = self.prior_tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=self.prior_tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_text_mask = uncond_input.attention_mask.bool().to(device)
    uncond_encoded = self.prior_text_encoder(uncond_input.input_ids.to(device))

    negative_prompt_embeds = uncond_encoded.text_embeds
    uncond_hidden_states = uncond_encoded.last_hidden_state

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    total_rows = batch_size * num_images_per_prompt

    embed_width = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
    negative_prompt_embeds = negative_prompt_embeds.view(total_rows, embed_width)

    hidden_seq_len = uncond_hidden_states.shape[1]
    uncond_hidden_states = uncond_hidden_states.repeat(1, num_images_per_prompt, 1)
    uncond_hidden_states = uncond_hidden_states.view(total_rows, hidden_seq_len, -1)

    uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    prior_text_encoder_hidden_states = torch.cat([uncond_hidden_states, prior_text_encoder_hidden_states])
    text_mask = torch.cat([uncond_text_mask, text_mask])

    return prompt_embeds, prior_text_encoder_hidden_states, text_mask
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 
num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs with prepare_extra_step_kwargs->prepare_prior_extra_step_kwargs, scheduler->prior_scheduler def prepare_prior_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers. 
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    noise_level,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the user-facing arguments of the pipeline call before any model is run.

    Args:
        prompt: `str`, `list` or `None`. Exactly one of `prompt` / `prompt_embeds` must be given.
        height: Requested image height; must be divisible by 8.
        width: Requested image width; must be divisible by 8.
        callback_steps: Positive `int`; `None` is rejected as well.
        noise_level: Must satisfy `0 <= noise_level < image_noising_scheduler.config.num_train_timesteps`.
        negative_prompt: Optional negative prompt; mutually exclusive with `negative_prompt_embeds` and, when
            `prompt` is given, required to be of the same type as `prompt`.
        prompt_embeds: Optional pre-computed prompt embeddings.
        negative_prompt_embeds: Optional pre-computed negative embeddings; when `prompt_embeds` is given too,
            both must have the same `.shape`.

    Raises:
        ValueError: If any of the constraints above (except the type match) is violated.
        TypeError: If `prompt` and `negative_prompt` are both given but differ in type.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` fails the isinstance check, so it is rejected together with non-ints and non-positive values.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
        )

    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )

    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        # Both were supplied: the message must say "only one", not "cannot leave both undefined".
        raise ValueError(
            "Provide either `negative_prompt` or `negative_prompt_embeds`. Please make sure to define only one of"
            " the two."
        )

    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
        raise ValueError(
            f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive."
        )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    # regular denoising process args
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 20,
    guidance_scale: float = 10.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 0,
    # prior args
    prior_num_inference_steps: int = 25,
    prior_guidance_scale: float = 4.0,
    prior_latents: Optional[torch.FloatTensor] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Generation runs in two stages: a prior denoising loop first predicts an image embedding from the prompt;
    that embedding is then noised according to `noise_level` and passed as `class_labels` to the UNet in the
    regular latent denoising loop.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 20):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 10.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. When set to
            `"latent"`, the VAE decode step is skipped and the latents are post-processed directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.FloatTensor)`. It is invoked from both the prior loop (with the prior latents) and the main
            denoising loop (with the image latents).
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, if present, is also used as the text-encoder LoRA scale.
        noise_level (`int`, *optional*, defaults to `0`):
            The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
            the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
        prior_num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps in the prior denoising process. More denoising steps usually lead to a
            higher quality image at the expense of slower inference.
        prior_guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting
            `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
            linked to the text `prompt`, usually at the expense of lower image quality.
        prior_latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            embedding generation in the prior denoising process. Can be used to tweak the same generation with
            different prompts. If not provided, a latents tensor will be generated by sampling using the supplied
            random `generator`.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is
        True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        callback_steps=callback_steps,
        noise_level=noise_level,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # from here on `batch_size` is the total number of images to generate
    batch_size = batch_size * num_images_per_prompt

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    prior_do_classifier_free_guidance = prior_guidance_scale > 1.0

    # 3. Encode input prompt
    # NOTE(review): only `prompt` is forwarded to the prior encoder; `prompt_embeds` is not used for the
    # prior stage — confirm how `_encode_prior_prompt` behaves when `prompt` is None.
    prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=prior_do_classifier_free_guidance,
    )

    # 4. Prepare prior timesteps
    self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
    prior_timesteps_tensor = self.prior_scheduler.timesteps

    # 5. Prepare prior latent variables
    embedding_dim = self.prior.config.embedding_dim
    prior_latents = self.prepare_latents(
        (batch_size, embedding_dim),
        prior_prompt_embeds.dtype,
        device,
        generator,
        prior_latents,
        self.prior_scheduler,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta)

    # 7. Prior denoising loop
    for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents
        latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t)

        predicted_image_embedding = self.prior(
            latent_model_input,
            timestep=t,
            proj_embedding=prior_prompt_embeds,
            encoder_hidden_states=prior_text_encoder_hidden_states,
            attention_mask=prior_text_mask,
        ).predicted_image_embedding

        if prior_do_classifier_free_guidance:
            # first half of the batch is the unconditional prediction, second half the text-conditioned one
            predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
            predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                predicted_image_embedding_text - predicted_image_embedding_uncond
            )

        prior_latents = self.prior_scheduler.step(
            predicted_image_embedding,
            timestep=t,
            sample=prior_latents,
            **prior_extra_step_kwargs,
            return_dict=False,
        )[0]

        # the same user callback is also used by the main denoising loop below
        if callback is not None and i % callback_steps == 0:
            callback(i, t, prior_latents)

    prior_latents = self.prior.post_process_latents(prior_latents)

    image_embeds = prior_latents

    # done prior

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 8. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 9. Prepare image embeddings
    image_embeds = self.noise_image_embeddings(
        image_embeds=image_embeds,
        noise_level=noise_level,
        generator=generator,
    )

    if do_classifier_free_guidance:
        # NOTE: the name `negative_prompt_embeds` is re-bound here to an all-zero *image* embedding; the
        # negative text embeddings were already folded into `prompt_embeds` in step 8.
        negative_prompt_embeds = torch.zeros_like(image_embeds)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        image_embeds = torch.cat([negative_prompt_embeds, image_embeds])

    # 10. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 11. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    latents = self.prepare_latents(
        shape=shape,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
        latents=latents,
        scheduler=self.scheduler,
    )

    # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 13. Denoising loop
    for i, t in enumerate(self.progress_bar(timesteps)):
        # duplicate the latents so one UNet call covers both guidance branches
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            class_labels=image_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    # decode through the VAE unless the caller asked for raw latents
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
# See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.utils.import_utils import is_accelerate_available from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableUnCLIPImg2ImgPipeline >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 ... ) # TODO update model path >>> pipe = pipe.to("cuda") >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) >>> prompt = "A fantasy landscape, trending on artstation" >>> images = pipe(prompt, init_image).images >>> images[0].save("fantasy_landscape.png") ``` """ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): """ Pipeline for text-guided image to image generation using stable unCLIP. 
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: feature_extractor ([`CLIPImageProcessor`]): Feature extractor for image pre-processing before being encoded. image_encoder ([`CLIPVisionModelWithProjection`]): CLIP vision model for encoding images. image_normalizer ([`StableUnCLIPImageNormalizer`]): Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image embeddings after the noise has been applied. image_noising_scheduler ([`KarrasDiffusionSchedulers`]): Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). text_encoder ([`CLIPTextModel`]): Frozen text-encoder. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`KarrasDiffusionSchedulers`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
""" # image encoding components feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection # image noising components image_normalizer: StableUnCLIPImageNormalizer image_noising_scheduler: KarrasDiffusionSchedulers # regular denoising components tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers vae: AutoencoderKL def __init__( self, # image encoding components feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, # image noising components image_normalizer: StableUnCLIPImageNormalizer, image_noising_scheduler: KarrasDiffusionSchedulers, # regular denoising components tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, # vae vae: AutoencoderKL, ): super().__init__() self.register_modules( feature_extractor=feature_extractor, image_encoder=image_encoder, image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, vae=vae, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. 
""" self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list models = [ self.image_encoder, self.text_encoder, self.unet, self.vae, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def _encode_image( self, image, device, batch_size, num_images_per_prompt, do_classifier_free_guidance, noise_level, generator, image_embeds, ): dtype = next(self.image_encoder.parameters()).dtype if isinstance(image, PIL.Image.Image): # the image embedding should repeated so it matches the total batch size of the prompt repeat_by = batch_size else: # assume the image input is already properly batched and just needs to be repeated so # it matches the num_images_per_prompt. # # NOTE(will) this is probably missing a few number of side cases. I.e. batched/non-batched # `image_embeds`. If those happen to be common use cases, let's think harder about # what the expected dimensions of inputs should be and how we handle the encoding. repeat_by = num_images_per_prompt if image_embeds is None: if not isinstance(image, torch.Tensor): image = self.feature_extractor(images=image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) image_embeds = self.image_encoder(image).image_embeds image_embeds = self.noise_image_embeddings( image_embeds=image_embeds, noise_level=noise_level, generator=generator, ) # duplicate image embeddings for each generation per prompt, using mps friendly method image_embeds = image_embeds.unsqueeze(1) bs_embed, seq_len, _ = image_embeds.shape image_embeds = image_embeds.repeat(1, repeat_by, 1) image_embeds = image_embeds.view(bs_embed * repeat_by, seq_len, -1) image_embeds = image_embeds.squeeze(1) if do_classifier_free_guidance: negative_prompt_embeds = torch.zeros_like(image_embeds) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) return image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, height, width, callback_steps, noise_level, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, image_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." ) if prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." 
) if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: raise ValueError( f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." ) if image is not None and image_embeds is not None: raise ValueError( "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." ) if image is None and image_embeds is None: raise ValueError( "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." ) if image is not None: if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" f" {type(image)}" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings def noise_image_embeddings( self, image_embeds: torch.Tensor, noise_level: int, noise: Optional[torch.FloatTensor] = None, generator: Optional[torch.Generator] = None, ): """ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher `noise_level` increases the variance in the final un-noised images. The noise is applied in two ways 1. A noise schedule is applied directly to the embeddings 2. A vector of sinusoidal time embeddings are appended to the output. In both cases, the amount of noise is controlled by the same `noise_level`. The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. """ if noise is None: noise = randn_tensor( image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype ) noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) self.image_normalizer.to(image_embeds.device) image_embeds = self.image_normalizer.scale(image_embeds) image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) image_embeds = self.image_normalizer.unscale(image_embeds) noise_level = get_timestep_embedding( timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 ) # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, # but we might actually be running in fp16. so we need to cast here. 
# there might be better ways to encapsulate this. noise_level = noise_level.to(image_embeds.dtype) image_embeds = torch.cat((image_embeds, noise_level), 1) return image_embeds @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image: Union[torch.FloatTensor, PIL.Image.Image] = None, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 20, guidance_scale: float = 10, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, noise_level: int = 0, image_embeds: Optional[torch.FloatTensor] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be used or prompt is initialized to `""`. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the latents in the denoising process such as in the standard stable diffusion text guided image variation process. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 20): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 10.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. image_embeds (`torch.FloatTensor`, *optional*): Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor if prompt is None and prompt_embeds is None: prompt = len(image) * [""] if isinstance(image, list) else "" # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, image=image, height=height, width=width, callback_steps=callback_steps, noise_level=noise_level, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, image_embeds=image_embeds, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] batch_size = batch_size * num_images_per_prompt device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. 
Encoder input image noise_level = torch.tensor([noise_level], device=device) image_embeds = self._encode_image( image=image, device=device, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, noise_level=noise_level, generator=generator, image_embeds=image_embeds, ) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size=batch_size, num_channels_latents=num_channels_latents, height=height, width=width, dtype=prompt_embeds.dtype, device=device, generator=generator, latents=latents, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=image_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. 
Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/safety_checker.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np import torch import torch.nn as nn from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel from ...utils import logging logger = logging.get_logger(__name__) def cosine_distance(image_embeds, text_embeds): normalized_image_embeds = nn.functional.normalize(image_embeds) normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig _no_split_modules = ["CLIPEncoderLayer"] def __init__(self, config: CLIPConfig): super().__init__(config) self.vision_model = CLIPVisionModel(config.vision_config) self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) @torch.no_grad() def forward(self, clip_input, images): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() result = [] batch_size = image_embeds.shape[0] for i in range(batch_size): result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} # increase this value to create a stronger `nfsw` filter # at the cost of increasing the possibility of filtering benign images adjustment = 0.0 for concept_idx in range(len(special_cos_dist[0])): concept_cos = 
special_cos_dist[i][concept_idx] concept_threshold = self.special_care_embeds_weights[concept_idx].item() result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) if result_img["special_scores"][concept_idx] > 0: result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) adjustment = 0.01 for concept_idx in range(len(cos_dist[0])): concept_cos = cos_dist[i][concept_idx] concept_threshold = self.concept_embeds_weights[concept_idx].item() result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) if result_img["concept_scores"][concept_idx] > 0: result_img["bad_concepts"].append(concept_idx) result.append(result_img) has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): if has_nsfw_concept: if torch.is_tensor(images) or torch.is_tensor(images[0]): images[idx] = torch.zeros_like(images[idx]) # black image else: images[idx] = np.zeros(images[idx].shape) # black image if any(has_nsfw_concepts): logger.warning( "Potential NSFW content was detected in one or more images. A black image will be returned instead." " Try again with a different prompt and/or seed." 
) return images, has_nsfw_concepts @torch.no_grad() def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) cos_dist = cosine_distance(image_embeds, self.concept_embeds) # increase this value to create a stronger `nsfw` filter # at the cost of increasing the possibility of filtering benign images adjustment = 0.0 special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment # special_scores = special_scores.round(decimals=3) special_care = torch.any(special_scores > 0, dim=1) special_adjustment = special_care * 0.01 special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment # concept_scores = concept_scores.round(decimals=3) has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) images[has_nsfw_concepts] = 0.0 # black image return images, has_nsfw_concepts ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/safety_checker_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Optional, Tuple import jax import jax.numpy as jnp from flax import linen as nn from flax.core.frozen_dict import FrozenDict from transformers import CLIPConfig, FlaxPreTrainedModel from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule def jax_cosine_distance(emb_1, emb_2, eps=1e-12): norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T return jnp.matmul(norm_emb_1, norm_emb_2.T) class FlaxStableDiffusionSafetyCheckerModule(nn.Module): config: CLIPConfig dtype: jnp.dtype = jnp.float32 def setup(self): self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) self.special_care_embeds = self.param( "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) ) self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) def __call__(self, clip_input): pooled_output = self.vision_model(clip_input)[1] image_embeds = self.visual_projection(pooled_output) special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) # increase this value to create a stronger `nfsw` filter # at the cost of increasing the possibility of filtering benign image inputs adjustment = 0.0 special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment special_scores = jnp.round(special_scores, 3) is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) # Use a lower threshold if an image has any special care concept special_adjustment = 
is_special_care * 0.01 concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment concept_scores = jnp.round(concept_scores, 3) has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) return has_nsfw_concepts class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): config_class = CLIPConfig main_input_name = "clip_input" module_class = FlaxStableDiffusionSafetyCheckerModule def __init__( self, config: CLIPConfig, input_shape: Optional[Tuple] = None, seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, **kwargs, ): if input_shape is None: input_shape = (1, 224, 224, 3) module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor clip_input = jax.random.normal(rng, input_shape) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} random_params = self.module.init(rngs, clip_input)["params"] return random_params def __call__( self, clip_input, params: dict = None, ): clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) return self.module.apply( {"params": params or self.params}, jnp.array(clip_input, dtype=jnp.float32), rngs={}, ) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin


class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
    """
    This class is used to hold the mean and standard deviation of the CLIP embedder used in stable unCLIP.

    It is used to normalize the image embeddings before the noise is applied and un-normalize the noised image
    embeddings.
    """

    @register_to_config
    def __init__(
        self,
        embedding_dim: int = 768,
    ):
        """
        Args:
            embedding_dim (`int`, defaults to 768):
                Size of the last dimension of the `mean` and `std` parameters.
        """
        super().__init__()

        # Start as the identity transform: zero mean, unit standard deviation, shape (1, embedding_dim).
        self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
        self.std = nn.Parameter(torch.ones(1, embedding_dim))

    def to(
        self,
        torch_device: Optional[Union[str, torch.device]] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """
        Re-create `mean` and `std` as new parameters on `torch_device` with `torch_dtype` and return `self`.
        """
        self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
        self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
        return self

    def scale(self, embeds):
        """Normalize `embeds`: subtract `mean`, then divide by `std`."""
        embeds = (embeds - self.mean) * 1.0 / self.std
        return embeds

    def unscale(self, embeds):
        """Inverse of `scale`: multiply by `std`, then add `mean`."""
        embeds = (embeds * self.std) + self.mean
        return embeds


================================================
FILE: 4DoF/diffusers/pipelines/stable_diffusion_safe/__init__.py
================================================
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union

import numpy as np
import PIL
from PIL import Image

from ...utils import BaseOutput, is_torch_available, is_transformers_available


@dataclass
class SafetyConfig(object):
    """
    Named presets of safety-guidance settings.

    Each preset is a plain dict whose keys match the `sld_*` keyword arguments of
    `StableDiffusionPipelineSafe.__call__`.
    """

    WEAK = {
        "sld_warmup_steps": 15,
        "sld_guidance_scale": 20,
        "sld_threshold": 0.0,
        "sld_momentum_scale": 0.0,
        "sld_mom_beta": 0.0,
    }
    MEDIUM = {
        "sld_warmup_steps": 10,
        "sld_guidance_scale": 1000,
        "sld_threshold": 0.01,
        "sld_momentum_scale": 0.3,
        "sld_mom_beta": 0.4,
    }
    STRONG = {
        "sld_warmup_steps": 7,
        "sld_guidance_scale": 2000,
        "sld_threshold": 0.025,
        "sld_momentum_scale": 0.5,
        "sld_mom_beta": 0.7,
    }
    MAX = {
        "sld_warmup_steps": 0,
        "sld_guidance_scale": 5000,
        "sld_threshold": 1.0,
        "sld_momentum_scale": 0.5,
        "sld_mom_beta": 0.7,
    }


@dataclass
class StableDiffusionSafePipelineOutput(BaseOutput):
    """
    Output class for Safe Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
        unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work"
            (nsfw) content, or `None` if no safety check was performed or no images were flagged.
        applied_safety_concept (`str`)
            The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
    unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
    applied_safety_concept: Optional[str]


# The pipeline and its safety checker are only importable when both transformers and torch are installed.
if is_transformers_available() and is_torch_available():
    from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
    from .safety_checker import SafeStableDiffusionSafetyChecker


================================================
FILE: 4DoF/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
================================================
import inspect
import warnings
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableDiffusionPipelineSafe(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Safe Latent Diffusion.

    The implementation is based on the [`StableDiffusionPipeline`]

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder.
Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: SafeStableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() safety_concept: Optional[str] = ( "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" " abuse, brutality, cruelty" ) if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. 
`steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. 
For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self._safety_text_concept = safety_concept self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) @property def safety_concept(self): r""" Getter method for the safety concept used with SLD Returns: `str`: The text describing the safety concept """ return self._safety_text_concept @safety_concept.setter def safety_concept(self, concept): r""" Setter method for the safety concept used with SLD Args: concept (`str`): The text of the new safety concept """ self._safety_text_concept = concept def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device("cuda") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
""" batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # Encode the safety concept text if enable_safety_guidance: safety_concept_input = self.tokenizer( [self._safety_text_concept], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] # duplicate safety embeddings for each generation per prompt, using mps friendly method seq_len = safety_embeddings.shape[1] safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance + sld, we need to do three forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing three forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) else: # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def run_safety_checker(self, image, device, dtype, enable_safety_guidance): if self.safety_checker is not None: images = image.copy() safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) flagged_images = np.zeros((2, *image.shape[1:])) if any(has_nsfw_concept): logger.warning( "Potential NSFW content was detected in one or more images. A black image will be returned" " instead." f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}" ) for idx, has_nsfw_concept in enumerate(has_nsfw_concept): if has_nsfw_concept: flagged_images[idx] = images[idx] image[idx] = np.zeros(image[idx].shape) # black image else: has_nsfw_concept = None flagged_images = None return image, has_nsfw_concept, flagged_images # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def perform_safety_guidance( self, enable_safety_guidance, safety_momentum, noise_guidance, noise_pred_out, i, sld_guidance_scale, sld_warmup_steps, sld_threshold, sld_momentum_scale, sld_mom_beta, ): # Perform SLD guidance if enable_safety_guidance: if safety_momentum is None: safety_momentum = torch.zeros_like(noise_guidance) noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] noise_pred_safety_concept = noise_pred_out[2] # Equation 6 scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) # Equation 6 safety_concept_scale = torch.where( (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale ) # Equation 4 noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) # Equation 7 noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum # Equation 8 safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety if i >= sld_warmup_steps: # Warmup # Equation 3 noise_guidance = noise_guidance - noise_guidance_safety return noise_guidance, safety_momentum @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, 
torch.FloatTensor], None]] = None, callback_steps: int = 1, sld_guidance_scale: Optional[float] = 1000, sld_warmup_steps: Optional[int] = 10, sld_threshold: Optional[float] = 0.01, sld_momentum_scale: Optional[float] = 0.3, sld_mom_beta: Optional[float] = 0.4, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. sld_guidance_scale (`float`, *optional*, defaults to 1000): Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be disabled. sld_warmup_steps (`int`, *optional*, defaults to 10): Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). sld_threshold (`float`, *optional*, defaults to 0.01): Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold` is defined as `lamda` of Eq. 
5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). sld_momentum_scale (`float`, *optional*, defaults to 0.3): Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0 momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). sld_mom_beta (`float`, *optional*, defaults to 0.4): Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance if not enable_safety_guidance: warnings.warn("Safety checker disabled!") # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) safety_momentum = None num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] # default classifier free guidance noise_guidance = noise_pred_text - noise_pred_uncond # Perform SLD guidance if enable_safety_guidance: if safety_momentum is None: safety_momentum = torch.zeros_like(noise_guidance) noise_pred_safety_concept = noise_pred_out[2] # Equation 6 scale = torch.clamp( torch.abs((noise_pred_text - noise_pred_safety_concept)) * 
sld_guidance_scale, max=1.0 ) # Equation 6 safety_concept_scale = torch.where( (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale, ) # Equation 4 noise_guidance_safety = torch.mul( (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale ) # Equation 7 noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum # Equation 8 safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety if i >= sld_warmup_steps: # Warmup # Equation 3 noise_guidance = noise_guidance - noise_guidance_safety noise_pred = noise_pred_uncond + guidance_scale * noise_guidance # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) # 9. Run safety checker image, has_nsfw_concept, flagged_images = self.run_safety_checker( image, device, prompt_embeds.dtype, enable_safety_guidance ) # 10. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if flagged_images is not None: flagged_images = self.numpy_to_pil(flagged_images) if not return_dict: return ( image, has_nsfw_concept, self._safety_text_concept if enable_safety_guidance else None, flagged_images, ) return StableDiffusionSafePipelineOutput( images=image, nsfw_content_detected=has_nsfw_concept, applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, unsafe_images=flagged_images, ) ================================================ FILE: 4DoF/diffusers/pipelines/stable_diffusion_safe/safety_checker.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
def cosine_distance(image_embeds, text_embeds):
    """
    Return the matrix of cosine similarities between image and text embeddings.

    Both inputs are normalized with `nn.functional.normalize` (default arguments) and combined with `torch.mm`,
    so entry `[i, j]` compares `image_embeds[i]` with `text_embeds[j]`. Despite the function name, a larger
    value means the two embeddings are *more* similar.
    """
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class SafeStableDiffusionSafetyChecker(PreTrainedModel):
    """
    Checker that flags images whose projected CLIP vision embedding is close to stored concept embeddings.

    An image is flagged when, for at least one of the stored concepts, its cosine similarity exceeds that
    concept's threshold (`concept_embeds_weights`). A match against one of the "special care" concepts tightens
    the check for the regular concepts by adding `0.01` to their scores.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        # Frozen parameters initialised to ones; presumably overwritten by checkpoint weights - TODO confirm.
        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        # Per-concept thresholds that the cosine similarity is compared against.
        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        """
        Score every image against the stored concepts and report which ones are flagged.

        Args:
            clip_input: input passed to the CLIP vision model.
            images: the images being checked; returned unchanged.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a list with one `bool` per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        result = []
        # one row of similarities per image in the batch
        for special_row, concept_row in zip(special_cos_dist, cos_dist):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx, concept_cos in enumerate(special_row):
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["special_scores"][concept_idx] = score
                if score > 0:
                    # record an ordered (index, score) pair; a set literal would lose the order and
                    # collapse to a single element whenever index and score compare equal
                    result_img["special_care"].append((concept_idx, score))
                    # once a special-care concept matched, every later score is checked more strictly
                    adjustment = 0.01

            for concept_idx, concept_cos in enumerate(concept_row):
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["concept_scores"][concept_idx] = score
                if score > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
        """
        Tensor-only variant of `forward` that avoids Python loops and NumPy conversion.

        Unlike `forward`, scores are not rounded and the `0.01` special-care adjustment is only applied to the
        regular concept scores.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a boolean tensor with one entry per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
        # special_scores = special_scores.round(decimals=3)
        special_care = torch.any(special_scores > 0, dim=1)
        # per-image adjustment (0.01 where any special-care concept matched), broadcast over all concepts
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        # concept_scores = concept_scores.round(decimals=3)
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        return images, has_nsfw_concepts
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a version of itself rescaled to the per-sample standard deviation of
    `noise_pred_text`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    `guidance_rescale=0.0` keeps `noise_cfg` as it is; `guidance_rescale=1.0` uses only the rescaled tensor.
    """

    def _per_sample_std(tensor):
        # standard deviation over every dimension except the leading one
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    text_std = _per_sample_std(noise_pred_text)
    cfg_std = _per_sample_std(noise_cfg)

    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * (text_std / cfg_std)

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
):
    """
    Register the pipeline components and derive the helper attributes used by the other methods.

    `force_zeros_for_empty_prompt` is stored in the pipeline config; `encode_prompt` reads it to decide whether
    a missing negative prompt is replaced by all-zero embeddings.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    # ratio between image and latent resolution: a factor of 2 per VAE block after the first
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # latent-space sample size taken from the UNet config
    self.default_sample_size = self.unet.config.sample_size

    self.watermark = StableDiffusionXLWatermarker()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding.

    When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
    steps. This is useful to save some memory and allow larger batch sizes.
    """
    # delegates entirely to the VAE
    self.vae.enable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
    computing decoding in one step.
    """
    # delegates entirely to the VAE
    self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
    r"""
    Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
    computing decoding in one step.
    """
    self.vae.disable_tiling()

def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. `accelerate.cpu_offload`
    is applied to `unet`, `text_encoder` (when the pipeline has one), `text_encoder_2` and `vae`, so that each
    is loaded to the GPU only when its `forward` method is called. Note that offloading happens on a submodule
    basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): index of the CUDA device used for execution.

    Raises:
        ImportError: if `accelerate` is not installed or is older than v0.14.0.
    """
    if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
        from accelerate import cpu_offload
    else:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
        # `text_encoder` may be absent (see `enable_model_cpu_offload` / `encode_prompt`); skip it then
        if cpu_offloaded_model is not None:
            cpu_offload(cpu_offloaded_model, device)

def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
    to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
    method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
    `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): index of the CUDA device used for execution.

    Raises:
        ImportError: if `accelerate` is not installed or is older than v0.17.0.
    """
    if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
        from accelerate import cpu_offload_with_hook
    else:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # models in the order they are hooked: text encoder(s), then unet, then vae
    model_sequence = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )
    model_sequence.extend([self.unet, self.vae])

    hook = None
    for cpu_offloaded_model in model_sequence:
        # each hook is chained to the previous one through `prev_module_hook`
        _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

    # We'll offload the last model manually.
    self.final_offload_hook = hook
def encode_prompt(
    self,
    prompt,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`.
        The two negative entries are `None` when classifier free guidance is disabled and they were not passed
        in by the caller.

    Raises:
        TypeError: if `negative_prompt` and `prompt` are of different types.
        ValueError: if a list `negative_prompt` does not match the batch size of `prompt`.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            # warn when the tokenizer had to cut the prompt
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(device),
                output_hidden_states=True,
            )

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]

            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

            prompt_embeds_list.append(prompt_embeds)

        # the hidden states of all text encoders are joined along the feature dimension
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)

            # pad the negative prompt to the same sequence length as the positive embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            # This branch is only reached with classifier free guidance enabled:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    # NOTE(review): caller-supplied `prompt_embeds` / `negative_prompt_embeds` are returned as passed (not
    # repeated for `num_images_per_prompt`), while the pooled embeddings below always are - confirm intent.
    bs_embed = pooled_prompt_embeds.shape[0]
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    # Without classifier free guidance there may be no negative pooled embeddings at all; repeating `None`
    # would raise an AttributeError, so only duplicate them when they exist.
    if negative_pooled_prompt_embeds is not None:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: raise ValueError( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. It is only
            forwarded to the scheduler when the scheduler's `step` accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated
            from `negative_prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Required whenever `prompt_embeds` is passed.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Required whenever `negative_prompt_embeds` is passed.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the denoised latents are returned without VAE decoding, watermarking or post-processing.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is passed along to the UNet call. Its `"scale"` entry, if present, is also
            used as the LoRA scale for the text encoders.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16.
            of that paper. Only applied when it is greater than 0 and classifier-free guidance is active.
        original_size (`Tuple[int]`, *optional*, defaults to `(height, width)`):
            Part of the additional "time ids" conditioning that is handed to the UNet.
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            Part of the additional "time ids" conditioning that is handed to the UNet.
        target_size (`Tuple[int]`, *optional*, defaults to `(height, width)`):
            Part of the additional "time ids" conditioning that is handed to the UNet.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        one-element `tuple` whose only entry is the generated images (or the latents when
        `output_type="latent"`).
    """
    # 0. Default height and width to unet
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    add_time_ids = self._get_add_time_ids(
        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
    )

    if do_classifier_free_guidance:
        # unconditional half first, conditional half second; `chunk(2)` below relies on this order
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # make sure the VAE is in float32 mode, as it overflows in float16
    self.vae.to(dtype=torch.float32)

    use_torch_2_0_or_xformers = isinstance(
        self.vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    if use_torch_2_0_or_xformers:
        self.vae.post_quant_conv.to(latents.dtype)
        self.vae.decoder.conv_in.to(latents.dtype)
        self.vae.decoder.mid_block.to(latents.dtype)
    else:
        latents = latents.float()

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.watermark.apply_watermark(image)
        image = self.image_processor.postprocess(image, output_type=output_type)
    else:
        # Latent output skips decoding/watermarking but must still go through the common exit
        # path below so that `return_dict` is honoured and the last model is offloaded.
        image = latents

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionXLImg2ImgPipeline >>> from diffusers.utils import load_image >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16 ... 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend the classifier-free-guidance prediction with a std-matched version of itself.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: guided noise prediction.
        noise_pred_text: text-conditioned noise prediction whose per-sample standard deviation is the target.
        guidance_rescale: blend weight; `0.0` leaves `noise_cfg` untouched, `1.0` returns the fully rescaled
            prediction.

    Returns:
        A tensor of the same shape as `noise_cfg`.
    """

    def _per_sample_std(tensor):
        # reduce over every dimension except the leading (batch) one, keeping dims for broadcasting
        non_batch_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=non_batch_dims, keepdim=True)

    text_std = _per_sample_std(noise_pred_text)
    cfg_std = _per_sample_std(noise_cfg)

    # bring the guided prediction back to the text prediction's scale (fixes overexposure)
    rescaled = noise_cfg * (text_std / cfg_std)

    # mix rescaled and original guidance results to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
):
    """Register the pipeline components and derive the image-processing helpers.

    Args:
        vae: autoencoder; the number of entries in `vae.config.block_out_channels` determines
            `vae_scale_factor`.
        text_encoder / tokenizer: first text encoder and its tokenizer. Both are listed in the class's
            `_optional_components`.
        text_encoder_2 / tokenizer_2: second text encoder (with projection) and its tokenizer.
        unet: conditional denoising UNet.
        scheduler: noise scheduler used together with `unet`.
        requires_aesthetics_score: stored in the pipeline config.
        force_zeros_for_empty_prompt: stored in the pipeline config.
    """
    super().__init__()

    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
    )
    self.register_modules(**components)

    # the two flags are persisted in the config, in the same order as before
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # scale factor is 2 ** (number of VAE block levels - 1)
    num_vae_levels = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_levels - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.watermark = StableDiffusionXLWatermarker()
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage.

    `accelerate.cpu_offload` is applied to the unet, both text encoders and the vae. Components that are
    not loaded (`text_encoder` is an optional component of this pipeline and may be `None`) are skipped.
    Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index of the CUDA device on which the offloaded models execute.

    Raises:
        ImportError: if `accelerate` is not installed or older than v0.14.0.
    """
    if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
        from accelerate import cpu_offload
    else:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
        # `text_encoder` is optional; handing `None` to `cpu_offload` would fail, so skip it,
        # mirroring the guard in `enable_model_cpu_offload`.
        if cpu_offloaded_model is not None:
            cpu_offload(cpu_offloaded_model, device)
def encode_prompt(
    self,
    prompt,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device; falls back to `self._execution_device` when not given
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`
            input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated
            from `negative_prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Overwritten with the pooled output of the last text encoder
            whenever `prompt_embeds` has to be computed here.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)`. The two
        negative entries are `None` when classifier-free guidance is disabled and no negative embeddings were
        supplied by the caller.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders (the first pair is optional)
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(device),
                output_hidden_states=True,
            )

            # The pooled output is overwritten on every iteration, so only the one of the final
            # text encoder survives the loop.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]

            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

            prompt_embeds_list.append(prompt_embeds)

        # hidden states of all encoders are joined along the feature axis
        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)

            # pad the negative prompt to the same sequence length as the positive embeddings
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # As above: only the pooled output of the final text encoder is kept.
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            # (this branch is only reached with classifier-free guidance enabled)
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    bs_embed = pooled_prompt_embeds.shape[0]
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    # Without classifier-free guidance (and without caller-supplied negatives) there is no negative
    # pooled embedding; repeating `None` used to raise AttributeError here.
    if negative_pooled_prompt_embeds is not None:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of the scheduler's timesteps that img2img will run.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; ``int(steps * strength)``
            steps are kept, capped at ``num_inference_steps``.
        device: Accepted for interface compatibility; not read here.

    Returns:
        A tuple ``(timesteps, steps)``: ``self.scheduler.timesteps`` sliced from
        the first kept step (scaled by ``self.scheduler.order``), and the number
        of inference steps that remain.
    """
    kept_steps = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - kept_steps, 0)

    scheduler = self.scheduler
    remaining_timesteps = scheduler.timesteps[first_step * scheduler.order :]
    remaining_steps = num_inference_steps - first_step
    return remaining_timesteps, remaining_steps
) elif isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) self.vae.to(dtype) init_latents = init_latents.to(dtype) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features if ( expected_add_embed_dim > passed_add_embed_dim and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise 
ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model." ) elif ( expected_add_embed_dim < passed_add_embed_dim and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim ): raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model." ) elif expected_add_embed_dim != passed_add_embed_dim: raise ValueError( f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        torch.FloatTensor,
        PIL.Image.Image,
        np.ndarray,
        List[torch.FloatTensor],
        List[PIL.Image.Image],
        List[np.ndarray],
    ] = None,
    strength: float = 0.3,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Tuple[int, int] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Tuple[int, int] = None,
    aesthetic_score: float = 6.0,
    negative_aesthetic_score: float = 2.5,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
            The image(s) to modify with the pipeline.
        strength (`float`, *optional*, defaults to 0.3):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Accepted for signature compatibility. This method does not read it: the starting latents are always
            built from `image` by `prepare_latents`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            denoised latents are returned without VAE decoding, watermarking or post-processing.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also forwarded to `encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
            that paper. Guidance rescale factor should fix overexposure when using zero terminal SNR. Only applied
            when classifier-free guidance is active and the value is greater than 0.
        original_size (`Tuple[int]`, *optional*):
            Size entry of the added time ids passed to the UNet. Defaults to the `(height, width)` derived from
            the prepared latents and `vae_scale_factor`.
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            Crop-coordinate entry of the added time ids passed to the UNet.
        target_size (`Tuple[int]`, *optional*):
            Target-size entry of the added time ids; only used when `requires_aesthetics_score` is disabled in the
            pipeline config. Defaults to the `(height, width)` derived from the prepared latents.
        aesthetic_score (`float`, *optional*, defaults to 6.0):
            Score appended to the positive added time ids when `requires_aesthetics_score` is enabled.
        negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
            Score appended to the negative (unconditional) added time ids when `requires_aesthetics_score` is
            enabled.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple` whose single element holds the generated images (or the latents when `output_type="latent"`).
    """
    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess image
    image = self.image_processor.preprocess(image)

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 6. Prepare latent variables (always derived from `image`; the `latents` argument is not read)
    latents = self.prepare_latents(
        image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )

    # 7. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    height, width = latents.shape[-2:]
    height = height * self.vae_scale_factor
    width = width * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 8. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    add_time_ids, add_neg_time_ids = self._get_add_time_ids(
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        dtype=prompt_embeds.dtype,
    )

    if do_classifier_free_guidance:
        # unconditional inputs go first so that `chunk(2)` below yields (uncond, text)
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # make sure the VAE is in float32 mode, as it overflows in float16
    self.vae.to(dtype=torch.float32)

    use_torch_2_0_or_xformers = isinstance(
        self.vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    if use_torch_2_0_or_xformers:
        self.vae.post_quant_conv.to(latents.dtype)
        self.vae.decoder.conv_in.to(latents.dtype)
        self.vae.decoder.mid_block.to(latents.dtype)
    else:
        latents = latents.float()

    if output_type == "latent":
        # Latents are handed back as-is. Unlike the previous early return, this path now
        # falls through so the offload hook runs and `return_dict=False` is honoured.
        image = latents
    else:
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.watermark.apply_watermark(image)
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
class StableDiffusionXLWatermarker:
    """Embed the module-level ``WATERMARK_BITS`` pattern into image batches.

    The encoder is a ``WatermarkEncoder`` configured once, at construction,
    with the ``"bits"`` watermark.
    """

    def __init__(self):
        self.watermark = WATERMARK_BITS
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def apply_watermark(self, images: torch.FloatTensor):
        """Return ``images`` with the watermark encoded via the ``"dwtDct"`` method.

        Args:
            images: Batch indexed as (batch, channel, height, width) — the
                permutes below fix that layout; values are mapped from
                [-1, 1] to [0, 255] before encoding.

        Returns:
            The input object itself when its last dimension is below 256
            (too small to encode); otherwise a new CPU tensor in the same
            layout, clamped to [-1, 1].
        """
        # can't encode images that are smaller than 256
        if images.shape[-1] < 256:
            return images

        # [-1, 1] -> [0, 255], channels-last float numpy array for the encoder
        pixel_values = 255 * (images / 2 + 0.5)
        channels_last = pixel_values.cpu().permute(0, 2, 3, 1).float().numpy()

        encoded = [self.encoder.encode(single_image, "dwtDct") for single_image in channels_last]
        stacked = np.array(encoded)

        # back to channels-first and the [-1, 1] range
        channels_first = torch.from_numpy(stacked).permute(0, 3, 1, 2)
        rescaled = 2 * (channels_first / 255 - 0.5)
        return torch.clamp(rescaled, min=-1.0, max=1.0)
class KarrasVePipeline(DiffusionPipeline):
    r"""
    Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
    the VE column of Table 1 from [1] for reference.

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
    differential equations." https://arxiv.org/abs/2011.13456

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`KarrasVeScheduler`]):
            Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
    """

    # add type hints for linting
    unet: UNet2DModel
    scheduler: KarrasVeScheduler

    def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Run the stochastic sampler and return the generated images.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. `"pil"` converts the result to PIL images; any other value
                leaves it as a NumPy array.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        resolution = self.unet.config.sample_size
        noise_shape = (batch_size, 3, resolution, resolution)

        unet = self.unet

        def preconditioned_output(x, noise_level):
            # The model inputs and output are adjusted by following eq. (213) in [1].
            return (noise_level / 2) * unet((x + 1) / 2, noise_level / 2).sample

        # sample x_0 ~ N(0, sigma_0^2 * I)
        sample = randn_tensor(noise_shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # here sigma_t == t_i from the paper
            sigma = self.scheduler.schedule[t]
            if t > 0:
                sigma_prev = self.scheduler.schedule[t - 1]
            else:
                sigma_prev = 0

            # 1. Select temporarily increased noise level sigma_hat
            # 2. Add new noise to move from sample_i to sample_hat
            sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)

            # 3. Predict the noise residual given the noise magnitude `sigma_hat`
            first_order_output = preconditioned_output(sample_hat, sigma_hat)

            # 4. Evaluate dx/dt at sigma_hat
            # 5. Take Euler step from sigma to sigma_prev
            step_result = self.scheduler.step(first_order_output, sigma_hat, sigma_prev, sample_hat)

            if sigma_prev != 0:
                # 6. Apply 2nd order correction
                second_order_output = preconditioned_output(step_result.prev_sample, sigma_prev)
                step_result = self.scheduler.step_correct(
                    second_order_output,
                    sigma_hat,
                    sigma_prev,
                    sample_hat,
                    step_result.prev_sample,
                    step_result["derivative"],
                )
            sample = step_result.prev_sample

        # map from [-1, 1] to [0, 1] and move to channels-last NumPy
        sample = (sample / 2 + 0.5).clamp(0, 1)
        image = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if return_dict:
            return ImagePipelineOutput(images=image)
        return (image,)
""" frames: Union[List[np.ndarray], torch.FloatTensor] try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_text_to_video_synth import TextToVideoSDPipeline from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401 from .pipeline_text_to_video_zero import TextToVideoZeroPipeline ================================================ FILE: 4DoF/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
def tensor2vid(video: torch.Tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) -> List[np.ndarray]:
    """
    Convert a normalized video tensor into a list of `uint8` frames.

    Args:
        video (`torch.Tensor`):
            Video laid out as `(batch, channels, frames, height, width)`. The tensor is left unmodified.
        mean (sequence of `float`):
            Per-channel mean added back when un-normalizing.
        std (sequence of `float`):
            Per-channel standard deviation multiplied back when un-normalizing.

    Returns:
        `List[np.ndarray]`: One `uint8` array per frame with shape `(height, batch * width, channels)`; the
        samples of a batch are laid out next to each other along the width axis.
    """
    # This code is adapted from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
    # reshape to ncfhw
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)

    # unnormalize back to [0,1]; the in-place ops run on a private copy so that the
    # tensor handed in by the caller is not silently overwritten
    video = video.clone().mul_(std).add_(mean).clamp_(0, 1)

    # prepare the final outputs
    i, c, f, h, w = video.shape
    images = video.permute(2, 3, 0, 4, 1).reshape(
        f, h, i * w, c
    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
    images = images.unbind(dim=0)  # prepare a list of individual (consecutive) frames
    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
    return images
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet3DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
):
    """Register the pipeline components and derive the VAE scale factor from the VAE config."""
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    # 2 ** (number of entries in `block_out_channels` minus one)
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload all models to CPU with accelerate, lowering memory usage at a small cost in speed.

    In contrast to `enable_sequential_cpu_offload`, a whole model is moved to the GPU when its `forward` method is
    called and stays there until the next model runs. This saves less memory than
    `enable_sequential_cpu_offload`, but is much faster because the `unet` is executed iteratively.

    Args:
        gpu_id (`int`, defaults to 0): Index of the CUDA device the models are executed on.

    Raises:
        ImportError: If accelerate is not installed or is older than the required version.
    """
    accelerate_is_usable = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_is_usable:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target_device = torch.device(f"cuda:{gpu_id}")

    pipeline_is_off_cpu = self.device.type != "cpu"
    if pipeline_is_off_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # Chain the hooks: each call receives the hook returned for the previous model.
    previous_hook = None
    for component in [self.text_encoder, self.unet, self.vae]:
        _offloaded, previous_hook = cpu_offload_with_hook(
            component, target_device, prev_module_hook=previous_hook
        )

    # We'll offload the last model manually.
    self.final_offload_hook = previous_hook
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Schedulers do not all share one `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler's `step` declares a parameter with that name.

    Args:
        generator: Value passed on as `generator` when the scheduler accepts it.
        eta: Value passed on as `eta` when the scheduler accepts it. eta (η) corresponds to η in the DDIM paper
            (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: The subset of `{"eta": eta, "generator": generator}` supported by the scheduler.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_arguments = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_arguments if name in step_parameters}
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 16,
    num_inference_steps: int = 50,
    guidance_scale: float = 9.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 9.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
            usually at the expense of lower video quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape
            `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. `"latent"` returns the denoised latents without decoding
            them, `"pt"` returns the decoded video as a `torch.FloatTensor`, and any other value returns a list
            of `np.array` frames.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
            of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    Returns:
        [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
        [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the generated frames (or the latents when
        `output_type` is `"latent"`).
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    num_images_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # reshape latents: fold the frame axis into the batch axis for the scheduler step.
            # Dedicated local names are used so the `height`/`width` arguments are not shadowed.
            bsz, channel, frames, latent_height, latent_width = latents.shape
            latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, latent_height, latent_width)
            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
                bsz * frames, channel, latent_height, latent_width
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # reshape latents back
            latents = (
                latents[None, :]
                .reshape(bsz, frames, channel, latent_height, latent_width)
                .permute(0, 2, 1, 3, 4)
            )

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 8. Post-processing. Every output type goes through the same exit path below so that
    # `return_dict=False` and the final offload hook are honoured for `"latent"` as well.
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents)
        video = video_tensor if output_type == "pt" else tensor2vid(video_tensor)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (video,)

    return TextToVideoSDPipelineOutput(frames=video)
import TextToVideoSDPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler >>> from diffusers.utils import export_to_video >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.to("cuda") >>> prompt = "spiderman running in the desert" >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames >>> # safe low-res video >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4") >>> # let's offload the text-to-image model >>> pipe.to("cpu") >>> # and load the image-to-image model >>> pipe = DiffusionPipeline.from_pretrained( ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15" ... ) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode >>> pipe.vae.enable_slicing() >>> # now let's upscale it >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] >>> # and denoise it >>> video_frames = pipe(prompt, video=video, strength=0.6).frames >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4") >>> video_path ``` """ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 # reshape to ncfhw mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) # unnormalize 
def preprocess_video(video):
    """Convert user-supplied video input into the tensor consumed by the pipeline.

    Args:
        video: A single `np.ndarray`, `torch.Tensor` or `PIL.Image.Image`, or a
            non-empty list whose items are all of one of those types.

    Returns:
        `torch.Tensor`: For NumPy / PIL input the frames are stacked (or
        concatenated when the items are already 5-D), `uint8` data is divided
        by 255, a leading axis is added to 4-D data and the last axis is moved
        to position 1. For tensor input the items are stacked / concatenated
        and axes 1 and 2 are swapped. In both cases the result is finally
        mapped with `2 * x - 1`. A tensor whose second axis has size 4 is
        treated as latents and returned without any further processing.

    Raises:
        ValueError: If `video` is not one of the supported formats, or is an
            empty list.
    """
    supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image)

    if isinstance(video, supported_formats):
        video = [video]
    elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
        # Build the message defensively: `video` may be `None` or a scalar (not
        # iterable), and `str.join` only accepts strings, not type objects.
        # Without this the intended ValueError is masked by a TypeError.
        received = [type(i) for i in video] if isinstance(video, (list, tuple)) else type(video)
        supported = ", ".join(f"{fmt.__module__}.{fmt.__qualname__}" for fmt in supported_formats)
        raise ValueError(f"Input is in incorrect format: {received}. Currently, we only support {supported}")

    if not video:
        # `video[0]` below would otherwise fail with an unhelpful IndexError.
        raise ValueError("`video` must contain at least one frame, but an empty list was passed.")

    if isinstance(video[0], PIL.Image.Image):
        video = [np.array(frame) for frame in video]

    if isinstance(video[0], np.ndarray):
        video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)

        if video.dtype == np.uint8:
            video = np.array(video).astype(np.float32) / 255.0

        if video.ndim == 4:
            video = video[None, ...]

        # presumably (batch, frames, height, width, channels) -> (batch, channels, frames, height, width)
        video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3))

    elif isinstance(video[0], torch.Tensor):
        video = torch.cat(video, dim=0) if video[0].ndim == 5 else torch.stack(video, dim=0)

        # don't need any preprocess if the video is latents
        channel = video.shape[1]
        if channel == 4:
            return video

        # move channels before num_frames
        video = video.permute(0, 2, 1, 3, 4)

    # normalize video
    video = 2.0 * video - 1.0

    return video
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet3DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
):
    """Register the pipeline components and derive the VAE scale factor.

    The five components are handed to `register_modules`; afterwards
    `vae_scale_factor` is set to `2 ** (number of VAE block_out_channels - 1)`.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload `unet`, `text_encoder` and `vae` to CPU through accelerate's `cpu_offload`.

    The pipeline is first moved to CPU (and the CUDA cache emptied) when it is not already there; each of the three
    models is then registered with `cpu_offload` using `cuda:{gpu_id}` as the execution device.

    Args:
        gpu_id (`int`, defaults to 0): Index of the CUDA device the models are executed on.

    Raises:
        ImportError: If accelerate is not available or older than `0.14.0`.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    already_on_cpu = self.device.type == "cpu"
    if not already_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        # otherwise we don't see the memory savings (but they probably exist)
        torch.cuda.empty_cache()

    for component in (self.unet, self.text_encoder, self.vae):
        cpu_offload(component, target_device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Device on which the pipeline's models are executed.

    When the UNet carries no `_hf_hook` attribute this is simply `self.device`. Otherwise the UNet's sub-modules are
    scanned in order and the first hook exposing a non-`None` `execution_device` decides the result; if none is
    found, `self.device` is returned.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
    """Decode 5-D video latents frame by frame with the VAE.

    The latents are multiplied by `1 / vae.config.scaling_factor`, axes 1 and 2
    are swapped and the first two axes are merged so the VAE sees a 4-D batch.
    The decoded `.sample` is reshaped back to five dimensions, axes 1 and 2 are
    swapped again, and the result is returned as float32.
    """
    inv_scale = 1 / self.vae.config.scaling_factor
    unscaled = inv_scale * latents

    n_videos, n_channels, n_frames, dim_a, dim_b = unscaled.shape
    per_frame = unscaled.permute(0, 2, 1, 3, 4)
    per_frame = per_frame.reshape(n_videos * n_frames, n_channels, dim_a, dim_b)

    decoded = self.vae.decode(per_frame).sample

    # -1 lets the decoder choose its own channel count; trailing dims come from the decoded frames
    unfolded_shape = (n_videos, n_frames, -1) + decoded.shape[2:]
    clip = decoded[None, :].reshape(unfolded_shape).permute(0, 2, 1, 3, 4)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return clip.float()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """Validate the user-facing call arguments.

    Raises:
        ValueError: If `strength` lies outside `[0, 1]`; if `callback_steps` is
            not a positive `int`; if `prompt` and `prompt_embeds` are both
            given or both missing; if `prompt` is neither `str` nor `list`; if
            `negative_prompt` and `negative_prompt_embeds` are both given; or
            if `prompt_embeds` and `negative_prompt_embeds` differ in shape.
    """
    strength_out_of_range = strength < 0 or strength > 1
    if strength_out_of_range:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # `None` is not an int, so the isinstance test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    video: Union[List[np.ndarray], torch.FloatTensor] = None,
    strength: float = 0.6,
    num_inference_steps: int = 50,
    guidance_scale: float = 15.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        video (`List[np.ndarray]` or `torch.FloatTensor`):
            `video` frames or tensor representing a video batch, that will be used as the starting point for the
            process. Can also accept video latents as `video`; if passing latents directly, they will not be
            encoded again.
        strength (`float`, *optional*, defaults to 0.6):
            Conceptually, indicates how much to transform the reference `video`. Must be between 0 and 1. `video`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `video`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 15.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
            usually at the expense of lower video quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents of shape `(batch_size, num_channel, num_frames, height, width)`. NOTE: this
            method never reads the argument; the latents it denoises are always derived from `video` via
            `prepare_latents`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. `"latent"` returns the denoised latents without decoding,
            `"pt"` returns the decoded video tensor, and any other value returns the result of `tensor2vid` on the
            decoded tensor.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
            plain tuple. Not consulted when `output_type == "latent"`, which always returns the output class.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also forwarded to `_encode_prompt` as the LoRA scale.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated frames.
    """
    # 0. This pipeline always generates a single video per prompt
    num_images_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess video
    video = preprocess_video(video)

    # 5. Prepare timesteps (`num_inference_steps` is rebound to the number of steps actually run)
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 6. Prepare latent variables
    latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # reshape latents: fold the frame axis into the batch axis for the scheduler step
            # NOTE(review): the last two dims are presumably (height, width), so the names below look
            # swapped; they are only forwarded to `reshape`, so the result does not depend on the naming.
            bsz, channel, frames, width, height = latents.shape
            latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # reshape latents back
            latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)

            # update the progress bar and call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # early exit: `return_dict` is not consulted and the final offload hook is not triggered here
    if output_type == "latent":
        return TextToVideoSDPipelineOutput(frames=latents)

    # when model offloading is enabled, move the UNet to CPU before decoding
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")

    video_tensor = self.decode_latents(latents)

    if output_type == "pt":
        video = video_tensor
    else:
        video = tensor2vid(video_tensor)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (video,)

    return TextToVideoSDPipelineOutput(frames=video)
""" def __init__(self, batch_size=2): self.batch_size = batch_size def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) is_cross_attention = encoder_hidden_states is not None if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Cross Frame Attention if not is_cross_attention: video_length = key.size()[0] // self.batch_size first_frame_index = [0] * video_length # rearrange keys to have batch and frames in the 1st and 2nd dims respectively key = rearrange_3(key, video_length) key = key[:, first_frame_index] # rearrange values to have batch and frames in the 1st and 2nd dims respectively value = rearrange_3(value, video_length) value = value[:, first_frame_index] # rearrange back to original shape key = rearrange_4(key) value = rearrange_4(value) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class CrossFrameAttnProcessor2_0: """ Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. Args: batch_size: The number that represents actual batch size, other than the frames. For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to 2, due to classifier-free guidance. 
""" def __init__(self, batch_size=2): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.batch_size = batch_size def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) inner_dim = hidden_states.shape[-1] if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) query = attn.to_q(hidden_states) is_cross_attention = encoder_hidden_states is not None if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Cross Frame Attention if not is_cross_attention: video_length = key.size()[0] // self.batch_size first_frame_index = [0] * video_length # rearrange keys to have batch and frames in the 1st and 2nd dims respectively key = rearrange_3(key, video_length) key = key[:, first_frame_index] # rearrange values to have batch and frames in the 1st and 2nd dims respectively value = rearrange_3(value, video_length) value = value[:, first_frame_index] # rearrange back to original shape key = rearrange_4(key) value = rearrange_4(value) head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for 
def coords_grid(batch, ht, wd, device):
    """
    Build a batch of identical pixel-coordinate grids.

    Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py

    Args:
        batch: number of copies of the grid to return
        ht: height of the grid
        wd: width of the grid
        device: device on which the grid is created

    Returns:
        coords: float tensor of shape `(batch, 2, ht, wd)`; channel 0 holds the x (column) index and channel 1 the y
        (row) index of every position
    """
    # "ij" is the layout torch.meshgrid has always produced by default. Passing it explicitly silences the
    # deprecation warning and keeps the result stable should the default ever change.
    rows, cols = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij")
    coords = torch.stack((cols, rows), dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def warp_single_latent(latent, reference_flow):
    """
    Warp latent of a single frame with given flow

    Args:
        latent: latent code of a single frame, a 4D tensor with a leading batch axis
        reference_flow: flow which to warp the latent with, a 4D tensor whose channel 0 is the x displacement and
            channel 1 the y displacement

    Returns:
        warped: warped latent, same spatial size as `latent`
    """
    _, _, H, W = reference_flow.size()
    _, _, h, w = latent.size()
    coords0 = coords_grid(1, H, W, device=latent.device).to(latent.dtype)

    # Displace every grid position by the flow, then normalise x by the width and y by the height of the flow field.
    coords_t0 = coords0 + reference_flow
    coords_t0[:, 0] /= W
    coords_t0[:, 1] /= H

    # Map from [0, 1] to the [-1, 1] range used by grid_sample.
    coords_t0 = coords_t0 * 2.0 - 1.0
    # align_corners=False is the library default; it is spelled out so the call does not depend on (or warn
    # about) the implicit value.
    coords_t0 = F.interpolate(coords_t0, size=(h, w), mode="bilinear", align_corners=False)
    coords_t0 = torch.permute(coords_t0, (0, 2, 3, 1))

    warped = grid_sample(latent, coords_t0, mode="nearest", padding_mode="reflection", align_corners=False)
    return warped


def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
    """
    Create translation motion field

    Args:
        motion_field_strength_x: motion strength along x-axis
        motion_field_strength_y: motion strength along y-axis
        frame_ids: indexes of the frames the latents of which are being processed.
            This is needed when we perform chunk-by-chunk inference
        device: device
        dtype: dtype

    Returns:
        reference_flow: tensor of shape `(len(frame_ids), 2, 512, 512)`. For every frame, channel 0 is filled with
        `motion_field_strength_x * frame_id` and channel 1 with `motion_field_strength_y * frame_id`.
    """
    seq_length = len(frame_ids)
    reference_flow = torch.zeros((seq_length, 2, 512, 512), device=device, dtype=dtype)
    for fr_idx, frame_id in enumerate(frame_ids):
        reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * frame_id
        reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * frame_id
    return reference_flow


def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
    """
    Creates translation motion and warps the latents accordingly

    Args:
        motion_field_strength_x: motion strength along x-axis
        motion_field_strength_y: motion strength along y-axis
        frame_ids: indexes of the frames the latents of which are being processed.
            This is needed when we perform chunk-by-chunk inference
        latents: latent codes of frames

    Returns:
        warped_latents: warped latents, a new tensor with the same shape as `latents`
    """
    motion_field = create_motion_field(
        motion_field_strength_x=motion_field_strength_x,
        motion_field_strength_y=motion_field_strength_y,
        frame_ids=frame_ids,
        device=latents.device,
        dtype=latents.dtype,
    )
    # Work on a detached copy so the input latents are left untouched.
    warped_latents = latents.clone().detach()
    for i, latent in enumerate(latents):
        # motion_field is indexed (not zipped) so a too-short `frame_ids` still raises instead of truncating.
        warped_latents[i] = warp_single_latent(latent[None], motion_field[i][None])
    return warped_latents
def forward_loop(self, x_t0, t0, t1, generator):
    """
    Apply the DDPM forward (noising) process to a latent, moving it from time t0 to time t1 in a single step.

    Args:
        x_t0: latent code at time t0
        t0: index of the first scheduler alpha included in the step
        t1: index one past the last scheduler alpha included in the step
        generator: torch.Generator object used to draw the noise

    Returns:
        x_t1: `x_t0` scaled by the square root of the product of `self.scheduler.alphas[t0:t1]`, plus Gaussian
        noise scaled by the square root of one minus that product.
    """
    # Draw the noise first so the generator is consumed before anything else happens.
    noise = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)

    alpha_product = self.scheduler.alphas[t0:t1].prod()
    signal_scale = alpha_product.sqrt()
    noise_scale = (1 - alpha_product).sqrt()
    return signal_scale * x_t0 + noise_scale * noise
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    video_length: Optional[int] = 8,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_videos_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    motion_field_strength_x: float = 12,
    motion_field_strength_y: float = 12,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    t0: int = 44,
    t1: int = 47,
    frame_ids: Optional[List[int]] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    The first frame is denoised down to time T_0. Its latents are copied to the remaining frames, warped with a
    translation motion field, noised back up to time T_1, and finally all frames are denoised together from T_1.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation. A single string is wrapped into a one-element
            list.
        video_length (`int`, *optional*, defaults to 8):
            The number of generated video frames. Must be greater than 0.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. A single string is wrapped into a
            one-element list. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater
            than `1`).
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt. Only the value 1 is accepted.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        motion_field_strength_x (`float`, *optional*, defaults to 12):
            Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
            Sect. 3.3.1.
        motion_field_strength_y (`float`, *optional*, defaults to 12):
            Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
            Sect. 3.3.1.
        output_type (`str`, *optional*, defaults to `"tensor"`):
            The output format of the generated frames. `"latent"` returns the denoised latent codes without
            decoding them or running the safety checker; any other value returns the output of
            `self.decode_latents` after it has been passed through the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        t0 (`int`, *optional*, defaults to 44):
            Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
            [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
        t1 (`int`, *optional*, defaults to 47):
            Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
            [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
        frame_ids (`List[int]`, *optional*):
            Indexes of the frames that are being generated. This is used when generating longer videos
            chunk-by-chunk. Defaults to `list(range(video_length))` and must contain exactly `video_length`
            entries. Only `frame_ids[1:]` are used to build the motion field.

    Returns:
        [`TextToVideoPipelineOutput`] if `return_dict` is `True`, otherwise the tuple `(image,
        has_nsfw_concept)`. `image` holds the latent codes when `output_type == "latent"` and the decoded,
        safety-checked frames otherwise; `has_nsfw_concept` is `None` when `output_type == "latent"` and the
        result of the `safety_checker` otherwise.

    Raises:
        AssertionError: if `video_length` is not positive, `frame_ids` does not have `video_length` entries, or
            `num_videos_per_prompt` is not 1.
    """
    assert video_length > 0
    if frame_ids is None:
        frame_ids = list(range(video_length))
    assert len(frame_ids) == video_length

    assert num_videos_per_prompt == 1

    if isinstance(prompt, str):
        prompt = [prompt]
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]

    # Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps)

    # Define call parameters
    # NOTE(review): a `str` prompt has already been wrapped into a list above, so the `str` branch is never taken.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    # Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

    # Perform the first backward process up to time T_1
    # (t0 and t1 are counted from the end of `timesteps`, hence the negative offsets)
    x_1_t1 = self.backward_loop(
        timesteps=timesteps[: -t1 - 1],
        prompt_embeds=prompt_embeds,
        latents=latents,
        guidance_scale=guidance_scale,
        callback=callback,
        callback_steps=callback_steps,
        extra_step_kwargs=extra_step_kwargs,
        num_warmup_steps=num_warmup_steps,
    )
    # Snapshot the scheduler as it is at T_1; it is restored below before all frames are denoised from T_1.
    scheduler_copy = copy.deepcopy(self.scheduler)

    # Perform the second backward process up to time T_0
    x_1_t0 = self.backward_loop(
        timesteps=timesteps[-t1 - 1 : -t0 - 1],
        prompt_embeds=prompt_embeds,
        latents=x_1_t1,
        guidance_scale=guidance_scale,
        callback=callback,
        callback_steps=callback_steps,
        extra_step_kwargs=extra_step_kwargs,
        num_warmup_steps=0,
    )

    # Propagate first frame latents at time T_0 to remaining frames
    x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)

    # Add motion in latents at time T_0
    x_2k_t0 = create_motion_field_and_warp_latents(
        motion_field_strength_x=motion_field_strength_x,
        motion_field_strength_y=motion_field_strength_y,
        latents=x_2k_t0,
        frame_ids=frame_ids[1:],
    )

    # Perform forward process up to time T_1
    x_2k_t1 = self.forward_loop(
        x_t0=x_2k_t0,
        t0=timesteps[-t0 - 1].item(),
        t1=timesteps[-t1 - 1].item(),
        generator=generator,
    )

    # Perform backward process from time T_1 to 0
    x_1k_t1 = torch.cat([x_1_t1, x_2k_t1])
    b, l, d = prompt_embeds.size()
    # Repeat every prompt embedding `video_length` times along the batch axis.
    prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)

    # Resume from the scheduler state captured at T_1.
    self.scheduler = scheduler_copy
    x_1k_0 = self.backward_loop(
        timesteps=timesteps[-t1 - 1 :],
        prompt_embeds=prompt_embeds,
        latents=x_1k_t1,
        guidance_scale=guidance_scale,
        callback=callback,
        callback_steps=callback_steps,
        extra_step_kwargs=extra_step_kwargs,
        num_warmup_steps=0,
    )
    latents = x_1k_0

    # manually for max memory savings
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.unet.to("cpu")
    torch.cuda.empty_cache()

    if output_type == "latent":
        # Latents are returned as they are: no decoding and no safety check.
        image = latents
        has_nsfw_concept = None
    else:
        image = self.decode_latents(latents)
        # Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import List, Optional, Tuple, Union import torch from torch.nn import functional as F from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import UnCLIPScheduler from ...utils import is_accelerate_available, logging, randn_tensor from .text_proj import UnCLIPTextProjModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class UnCLIPPipeline(DiffusionPipeline): """ Pipeline for text-to-image generation using unCLIP This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. text_proj ([`UnCLIPTextProjModel`]): Utility class to prepare and combine the embeddings before they are passed to the decoder. decoder ([`UNet2DConditionModel`]): The decoder to invert the image embedding into an image. super_res_first ([`UNet2DModel`]): Super resolution unet. Used in all but the last step of the super resolution diffusion process. super_res_last ([`UNet2DModel`]): Super resolution unet. Used in the last step of the super resolution diffusion process. prior_scheduler ([`UnCLIPScheduler`]): Scheduler used in the prior denoising process. 
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
    """
    Return the initial latents for a denoising process, scaled by the scheduler's initial noise sigma.

    Args:
        shape: expected shape of the latents
        dtype: dtype of freshly sampled latents
        device: device on which the latents are placed
        generator: random generator used when the latents have to be sampled
        latents: optional pre-generated latents; they must already have shape `shape`
        scheduler: scheduler providing `init_noise_sigma`

    Returns:
        latents: the supplied (moved to `device`) or freshly sampled latents multiplied by
        `scheduler.init_noise_sigma`

    Raises:
        ValueError: if `latents` is given and its shape differs from `shape`
    """
    if latents is not None:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return initial * scheduler.init_noise_sigma
text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] text_mask = text_attention_mask prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each 
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Register accelerate's `cpu_offload` on the decoder, the text projection, the text encoder and both super
    resolution unets, targeting the device `cuda:{gpu_id}`. Components that are `None` are skipped.

    Args:
        gpu_id: index of the CUDA device the offloaded models execute on

    Raises:
        ImportError: if accelerate is not available
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
    offload_candidates = (
        self.decoder,
        self.text_proj,
        self.text_encoder,
        self.super_res_first,
        self.super_res_last,
    )
    for module in offload_candidates:
        if module is None:
            continue
        cpu_offload(module, target_device)
This can only be left undefined if `text_model_output` and `text_attention_mask` is passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. prior_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the prior. More denoising steps usually lead to a higher quality image at the expense of slower inference. decoder_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality image at the expense of slower inference. super_res_num_inference_steps (`int`, *optional*, defaults to 7): The number of denoising steps for super resolution. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): Pre-generated noisy latents to be used as inputs for the prior. decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): Pre-generated noisy latents to be used as inputs for the decoder. super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): Pre-generated noisy latents to be used as inputs for the decoder. prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. decoder_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. text_model_output (`CLIPTextModelOutput`, *optional*): Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs can be passed for tasks like text embedding interpolations. Make sure to also pass `text_attention_mask` in this case. `prompt` can the be left to `None`. text_attention_mask (`torch.Tensor`, *optional*): Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention masks are necessary when passing `text_model_output`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
""" if prompt is not None: if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") else: batch_size = text_model_output[0].shape[0] device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) # prior self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) prior_timesteps_tensor = self.prior_scheduler.timesteps embedding_dim = self.prior.config.embedding_dim prior_latents = self.prepare_latents( (batch_size, embedding_dim), prompt_embeds.dtype, device, generator, prior_latents, self.prior_scheduler, ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = prior_timesteps_tensor[i + 1] prior_latents = self.prior_scheduler.step( predicted_image_embedding, timestep=t, sample=prior_latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample prior_latents = 
self.prior.post_process_latents(prior_latents) image_embeddings = prior_latents # done prior # decoder text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, prompt_embeds=prompt_embeds, text_encoder_hidden_states=text_encoder_hidden_states, do_classifier_free_guidance=do_classifier_free_guidance, ) if device.type == "mps": # HACK: MPS: There is a panic when padding bool tensors, # so cast to int tensor for the pad and back to bool afterwards text_mask = text_mask.type(torch.int) decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) decoder_text_mask = decoder_text_mask.type(torch.bool) else: decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps num_channels_latents = self.decoder.config.in_channels height = self.decoder.config.sample_size width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, decoder_latents, self.decoder_scheduler, ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents noise_pred = self.decoder( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, class_labels=additive_clip_time_embeddings, attention_mask=decoder_text_mask, ).sample if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + 
decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if i + 1 == decoder_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = decoder_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 decoder_latents = self.decoder_scheduler.step( noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample decoder_latents = decoder_latents.clamp(-1, 1) image_small = decoder_latents # done decoder # super res self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps channels = self.super_res_first.config.in_channels // 2 height = self.super_res_first.config.sample_size width = self.super_res_first.config.sample_size super_res_latents = self.prepare_latents( (batch_size, channels, height, width), image_small.dtype, device, generator, super_res_latents, self.super_res_scheduler, ) if device.type == "mps": # MPS does not support many interpolations image_upscaled = F.interpolate(image_small, size=[height, width]) else: interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: interpolate_antialias["antialias"] = True image_upscaled = F.interpolate( image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias ) for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): # no classifier free guidance if i == super_res_timesteps_tensor.shape[0] - 1: unet = self.super_res_last else: unet = self.super_res_first latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) noise_pred = unet( sample=latent_model_input, timestep=t, ).sample if i + 1 == super_res_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = super_res_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 super_res_latents = self.super_res_scheduler.step( 
class UnCLIPImageVariationPipeline(DiffusionPipeline):
    """
    Pipeline to generate variations from an input image using unCLIP

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        feature_extractor ([`CLIPImageProcessor`]):
            Preprocessor that converts the input image(s) into pixel values for the `image_encoder`.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.

    """

    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    feature_extractor: CLIPImageProcessor
    image_encoder: CLIPVisionModelWithProjection
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    def __init__(
        self,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        """Register every sub-model, processor and scheduler with the pipeline under its argument name."""
        super().__init__()

        self.register_modules(
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    # Draws fresh noise of `shape` when `latents` is None; otherwise checks the caller's latents against `shape`
    # (raising ValueError on mismatch) and moves them to `device`. Either way the result is multiplied by
    # `scheduler.init_noise_sigma`. The body is kept verbatim because it is declared a copy of the method below.
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
        """
        Tokenize and encode `prompt` with the text encoder.

        Args:
            prompt (`str` or `List[str]`): Text to encode; a non-list value counts as a batch of one.
            device: Device the token ids and attention masks are moved to before encoding.
            num_images_per_prompt (`int`): How many times each row is repeated (consecutively).
            do_classifier_free_guidance (`bool`):
                When set, embeddings of empty prompts are computed as well and concatenated *in front of* the
                prompt embeddings along the batch dimension.

        Returns:
            `tuple`: `(prompt_embeds, text_encoder_hidden_states, text_mask)`, where `text_mask` is the boolean
            tokenizer attention mask.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_mask = text_inputs.attention_mask.bool().to(device)
        text_encoder_output = self.text_encoder(text_input_ids.to(device))

        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            # NOTE: for the pooled embeddings `shape[1]` is the trailing (feature) dimension, despite the name.
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
        """
        Produce the image embeddings that condition the decoder.

        When `image_embeddings` is given it is used as is and `image` is ignored. Otherwise `image` is run through
        the feature extractor (unless it already is a tensor), cast to the image encoder's parameter dtype, moved to
        `device` and encoded. The embeddings are finally repeated `num_images_per_prompt` times per row.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if image_embeddings is None:
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeddings = self.image_encoder(image).image_embeds

        image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        return image_embeddings

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.

        Args:
            gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device the models execute on.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            # `_encode_image` moves its input to the execution device, so the image encoder has to be hooked as
            # well; otherwise it stays where it was while its input lands on the GPU.
            self.image_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
        num_images_per_prompt: int = 1,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        image_embeddings: Optional[torch.Tensor] = None,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                configuration of
                [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`. Can be left to `None` only when `image_embeddings` are passed.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the super resolution unets.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            image_embeddings (`torch.Tensor`, *optional*):
                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
                can be passed for tasks like image interpolations. `image` can then be left as `None`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] when `return_dict` is `True`, otherwise a one-element tuple holding the
            generated images.

        Raises:
            ValueError: If neither `image` nor `image_embeddings` is passed, or if user-supplied latents do not have
                the expected shape.
        """
        if image is None and image_embeddings is None:
            # Fail early with a clear message instead of an AttributeError on `None.shape` below.
            raise ValueError("Either `image` or `image_embeddings` has to be passed.")

        if image is not None:
            if isinstance(image, PIL.Image.Image):
                batch_size = 1
            elif isinstance(image, list):
                batch_size = len(image)
            else:
                batch_size = image.shape[0]
        else:
            batch_size = image_embeddings.shape[0]

        # The decoder is still text-conditioned; image variation always feeds it empty prompts.
        prompt = [""] * batch_size

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = decoder_guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance
        )

        image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)

        # decoder
        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        # The projection prepends `clip_extra_context_tokens` tokens, so the mask is left-padded to match.
        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.config.in_channels
        height = self.decoder.config.sample_size
        width = self.decoder.config.sample_size

        # Always go through `prepare_latents` so caller-supplied latents are shape-checked and moved to `device`
        # exactly like freshly sampled ones.
        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Each half is split along the channel axis; guidance is applied to the first part only and the
                # second part of the text-conditioned half is re-attached unchanged.
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        # The super resolution unets receive the noisy latents concatenated with the upscaled image along the
        # channel axis, hence half of `in_channels` for the latents.
        channels = self.super_res_first.config.in_channels // 2
        height = self.super_res_first.config.sample_size
        width = self.super_res_first.config.sample_size

        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            interpolate_antialias = {}
            # Only pass `antialias` when the installed torch version accepts it.
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents

        # done super res

        # post processing

        # Map from [-1, 1] to [0, 1] and convert to channels-last numpy.
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
    """
    Combine CLIP image and text embeddings into the two conditioning inputs of the decoder.

    Args:
        image_embeddings: CLIP image embeddings, one row per sample.
        prompt_embeds: Pooled text embeddings; their batch size must equal that of the (possibly guidance-extended)
            image embeddings.
        text_encoder_hidden_states: Per-token text encoder outputs.
        do_classifier_free_guidance: When set, one learned "unconditional" embedding per image row is placed in
            front of the image embeddings along the batch dimension.

    Returns:
        A tuple `(text_encoder_hidden_states, additive_clip_time_embeddings)`: the projected and normalised token
        states with the extra CLIP context tokens prepended along the sequence dimension, and the sum of the
        time-projected image and prompt embeddings.
    """
    if do_classifier_free_guidance:
        # Add the classifier free guidance embeddings to the image embeddings
        num_rows = image_embeddings.shape[0]
        unconditional_rows = self.learned_classifier_free_guidance_embeddings.unsqueeze(0).expand(num_rows, -1)
        image_embeddings = torch.cat([unconditional_rows, image_embeddings], dim=0)

    batch_size = prompt_embeds.shape[0]

    # The image embeddings batch size and the text embeddings batch size are equal
    assert image_embeddings.shape[0] == batch_size

    # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
    # adding CLIP embeddings to the existing timestep embedding, ...
    prompt_time_part = self.embedding_proj(prompt_embeds)
    image_time_part = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
    additive_clip_time_embeddings = image_time_part + prompt_time_part

    # ... and by projecting CLIP embeddings into four
    # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
    extra_tokens = (
        self.clip_extra_context_tokens_proj(image_embeddings)
        .reshape(batch_size, -1, self.clip_extra_context_tokens)
        .permute(0, 2, 1)
    )

    projected_states = self.text_encoder_hidden_states_norm(
        self.encoder_hidden_states_proj(text_encoder_hidden_states)
    )
    conditioned_states = torch.cat([extra_tokens, projected_states], dim=1)

    return conditioned_states, additive_clip_time_embeddings
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    The model wraps a `GPT2LMHeadModel` and feeds it a sequence of prefix embeddings (optionally passed through a
    linear encode/decode bottleneck) followed by the embedded text tokens.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the linear bottleneck used to encode the prefix. Must be given when `prefix_inner_dim`
            differs from `n_embd`; when `None`, the prefix is passed through unchanged.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model.
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with.
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the GPT-2 transformer.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd.
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / layer_idx + 1`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(
        self,
        prefix_length: int,
        prefix_inner_dim: int,
        prefix_hidden_dim: Optional[int] = None,
        vocab_size: int = 50257,  # Start of GPT2 config args
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        super().__init__()

        self.prefix_length = prefix_length

        # Without a bottleneck the prefix goes straight into GPT-2, so its width must already equal `n_embd`.
        if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
            raise ValueError(
                f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_inner_dim} and"
                f" `n_embd`: {n_embd} are not equal."
            )

        self.prefix_inner_dim = prefix_inner_dim
        self.prefix_hidden_dim = prefix_hidden_dim

        self.encode_prefix = (
            nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
            if self.prefix_hidden_dim is not None
            else nn.Identity()
        )
        self.decode_prefix = (
            nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
        )

        gpt_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Run the GPT-2 language model on the prefix embeddings followed by the embedded text tokens.

        Args:
            input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
                Text tokens; they are embedded with the GPT-2 token embedding table.
            prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, prefix_inner_dim)`):
                Prefix embedding to prepend to the embedded tokens.
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to the GPT-2 model together with the concatenated prefix + token embeddings,
                so it is expected to cover `prefix_length + max_seq_len` positions per sample.
            labels (`torch.Tensor`, *optional*):
                Only checked for presence. When not `None`, the language-modeling targets are rebuilt from
                `input_ids`, left-padded with `prefix_length` zero tokens; the values passed in are not used.

        Returns:
            The GPT-2 model output, or the tuple `(output, hidden)` when `prefix_hidden_dim` is set, where `hidden`
            is the encoded (bottleneck) prefix.
        """
        embedding_text = self.transformer.transformer.wte(input_ids)
        hidden = self.encode_prefix(prefix_embeds)
        prefix_embeds = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)

        if labels is not None:
            # Pad the targets over the prefix positions so they line up with `embedding_cat`.
            dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
            labels = torch.cat((dummy_token, input_ids), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
        if self.prefix_hidden_dim is not None:
            return out, hidden
        return out

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a `(batch_size, prefix_length)` int64 tensor of zeros on `device`."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def encode(self, prefix):
        """Apply the prefix encoder (`encode_prefix`) to `prefix`."""
        return self.encode_prefix(prefix)

    @torch.no_grad()
    def generate_captions(self, features, eos_token_id, device):
        """
        Generate caption tokens from encoded prefix features, one beam search per sample.

        Args:
            features (`torch.Tensor` of shape `(B, L, D)`):
                Encoded prefix features; each sample is passed through `decode_prefix` before decoding.
            eos_token_id (`int`):
                The token ID of the EOS token for the text decoder model.
            device:
                Device to perform text generation on.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: The best-scoring token sequence of every sample stacked along
            dim 0, and the corresponding sequence lengths. The tokens are not decoded to strings here.
        """
        generated_tokens = []
        generated_seq_lengths = []
        for feature in torch.split(features, 1, dim=0):
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            output_tokens, seq_lengths = self.generate_beam(
                input_embeds=feature, device=device, eos_token_id=eos_token_id
            )
            # Beams come back sorted best-first; keep only the top one.
            generated_tokens.append(output_tokens[0])
            generated_seq_lengths.append(seq_lengths[0])
        generated_tokens = torch.stack(generated_tokens)
        generated_seq_lengths = torch.stack(generated_seq_lengths)
        return generated_tokens, generated_seq_lengths

    @torch.no_grad()
    def generate_beam(
        self,
        input_ids=None,
        input_embeds=None,
        device=None,
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """
        Generates text using the given text prompt or token embedding via beam search. This implementation is based
        on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                must be supplied.
            input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. Takes
                precedence over `input_ids` when both are given.
            device:
                The device to perform beam search on. Defaults to the device of the input embeddings.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The maximum number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                Divides the logits before the softmax; non-positive values are treated as 1.0.
            eos_token_id (`int`, *optional*):
                The token ID of the EOS token for the text decoder model.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: The generated token sequences sorted by length-normalized score in
            descending order, and the sequence lengths in the same order (returned as a CPU tensor).

        Raises:
            ValueError: If neither `input_ids` nor `input_embeds` is supplied.
        """
        if input_embeds is None and input_ids is None:
            raise ValueError("One of `input_ids` and `input_embeds` must be supplied.")

        if input_embeds is not None:
            generated = input_embeds
        else:
            generated = self.transformer.transformer.wte(input_ids)

        if device is None:
            # Keep the beam bookkeeping next to the embeddings instead of on the default device.
            device = generated.device

        stop_token_index = eos_token_id
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

        for _ in range(entry_length):
            outputs = self.transformer(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()

            if scores is None:
                # First step: seed the beams with the top `beam_size` tokens of the single input sequence.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # A finished beam may only "emit" token 0 at log-prob 0, which leaves its score unchanged.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                # Pick the best (beam, token) pairs over the flattened beam x vocab grid.
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break

        scores = scores / seq_lengths
        order = scores.argsort(descending=True)
        # tokens tensors are already padded to max_seq_length
        output_texts = tokens[order]
        # Lengths are handed back on the CPU, as before, but gathered with one indexing op.
        seq_lengths = seq_lengths[order].cpu()
        return output_texts, seq_lengths
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place with truncated-normal samples; helper for `trunc_normal_`.

    Sampling is done by drawing uniformly between the CDF values of the two cut-offs and
    mapping the draw back through the inverse normal CDF. All tensor operations run under
    `torch.no_grad()`, and the same tensor object is returned.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):
        # Standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    lowest_ok_mean = a - 2 * std
    highest_ok_mean = b + 2 * std
    if (mean < lowest_ok_mean) or (mean > highest_ok_mean):
        logger.warning(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )

    with torch.no_grad():
        # CDF values of the lower and upper cut-off points
        cdf_low = norm_cdf((a - mean) / std)
        cdf_high = norm_cdf((b - mean) / std)

        # Uniform draw on [2*cdf_low - 1, 2*cdf_high - 1], i.e. the domain expected by erfinv
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)

        # Inverse-CDF transform -> truncated standard normal
        tensor.erfinv_()

        # Rescale and shift to the requested std and mean
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Guard against values that landed just outside the interval
        tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0) -> torch.Tensor:
    r"""Fill `tensor` in place with samples from a truncated normal distribution.

    The values are effectively drawn from :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampling method works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, modified in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same `tensor` object that was passed in.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class PatchEmbed(nn.Module):
    """Turn a 2D image-like tensor into a sequence of patch embeddings.

    A strided convolution with kernel size and stride equal to `patch_size` projects every patch to
    `embed_dim` channels. Optionally the result is flattened to a token sequence, layer-normalized
    (without affine parameters), and offset by a fixed 2D sin-cos position embedding.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        use_pos_embed=True,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm
        self.use_pos_embed = use_pos_embed

        # One convolution step per patch: kernel and stride both equal the patch size.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None

        if use_pos_embed:
            patch_count = (height // patch_size) * (width // patch_size)
            # The grid side is derived as sqrt(patch_count), i.e. a square grid is assumed here.
            grid_side = int(patch_count**0.5)
            table = get_2d_sincos_pos_embed(embed_dim, grid_side)
            # Non-persistent buffer: follows the module across devices but is not saved in the state dict.
            self.register_buffer("pos_embed", torch.from_numpy(table).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Project `latent` to patch embeddings and apply the optional flatten / norm / position steps."""
        patches = self.proj(latent)
        if self.flatten:
            patches = patches.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            patches = self.norm(patches)
        if not self.use_pos_embed:
            return patches
        return patches + self.pos_embed


class SkipBlock(nn.Module):
    """Merge a skip connection into the main branch: concatenate, project back to `dim`, LayerNorm."""

    def __init__(self, dim: int):
        super().__init__()

        # Maps the concatenated (2 * dim) features back down to dim.
        self.skip_linear = nn.Linear(2 * dim, dim)

        # Use torch.nn.LayerNorm for now, following the original code
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, skip):
        """Return `norm(skip_linear(concat(x, skip)))`, concatenating along the last dimension."""
        merged = torch.cat([x, skip], dim=-1)
        return self.norm(self.skip_linear(merged))
class UTransformerBlock(nn.Module):
    r"""
    Transformer block derived from `BasicTransformerBlock` that can run either pre-LayerNorm or post-LayerNorm.

    The normalization is applied inside each branch: to the branch input when `pre_layer_norm` is `True`, or to the
    branch output (before it is added back to the residual) when it is `False`. `AdaLayerNormZero` is not supported.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Required when `norm_type` is an adaptive norm.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether the first attention layer also attends to `encoder_hidden_states`.
        double_self_attention (`bool`, *optional*):
            Whether the second attention layer is built without a cross-attention dimension.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to normalize before the attention and feed-forward operations ("pre-LayerNorm"), as opposed to
            after ("post-LayerNorm"). `BasicTransformerBlock` corresponds to `pre_layer_norm = True`.
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        pre_layer_norm: bool = True,
        final_dropout: bool = False,
    ):
        super().__init__()

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.pre_layer_norm = pre_layer_norm

        has_second_attention = cross_attention_dim is not None or double_self_attention

        # Settings shared by both attention layers.
        shared_attn_kwargs = {
            "query_dim": dim,
            "heads": num_attention_heads,
            "dim_head": attention_head_dim,
            "dropout": dropout,
            "bias": attention_bias,
            "upcast_attention": upcast_attention,
        }

        def build_attention_norm():
            # Timestep-conditioned AdaLayerNorm when enabled, plain LayerNorm otherwise.
            if self.use_ada_layer_norm:
                return AdaLayerNorm(dim, num_embeds_ada_norm)
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        # 1. Self-Attn
        self.attn1 = Attention(
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            **shared_attn_kwargs,
        )

        # 2. Cross-Attn (acts as self-attention if encoder_hidden_states is None)
        if has_second_attention:
            self.attn2 = Attention(
                cross_attention_dim=None if double_self_attention else cross_attention_dim,
                **shared_attn_kwargs,
            )
        else:
            self.attn2 = None

        self.norm1 = build_attention_norm()
        self.norm2 = build_attention_norm() if has_second_attention else None

        # 3. Feed-forward (always a plain LayerNorm)
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

    def _apply_norm(self, norm_layer, states, timestep):
        """Call an attention norm layer, passing `timestep` only to the adaptive variant."""
        if self.use_ada_layer_norm:
            return norm_layer(states, timestep)
        return norm_layer(states)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        """Apply self-attention, optional cross-attention and feed-forward, each with a residual connection.

        `class_labels` is accepted for signature compatibility but is not used by this block.
        """
        pre_norm = self.pre_layer_norm
        extra_attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs

        # 1. Self-Attention
        branch_input = self._apply_norm(self.norm1, hidden_states, timestep) if pre_norm else hidden_states
        attn_output = self.attn1(
            branch_input,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **extra_attn_kwargs,
        )
        if not pre_norm:
            attn_output = self._apply_norm(self.norm1, attn_output, timestep)
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            branch_input = self._apply_norm(self.norm2, hidden_states, timestep) if pre_norm else hidden_states
            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
            # prepare attention mask here
            attn_output = self.attn2(
                branch_input,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **extra_attn_kwargs,
            )
            if not pre_norm:
                attn_output = self._apply_norm(self.norm2, attn_output, timestep)
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        branch_input = self.norm3(hidden_states) if pre_norm else hidden_states
        ff_output = self.ff(branch_input)
        if not pre_norm:
            ff_output = self.norm3(ff_output)
        hidden_states = ff_output + hidden_states

        return hidden_states
# Like UTransformerBlock except with LayerNorms on the residual backbone of the block
# Modified from diffusers.models.attention.BasicTransformerBlock
class UniDiffuserBlock(nn.Module):
    r"""
    Transformer block derived from `BasicTransformerBlock` whose LayerNorms sit on the residual backbone.

    In pre-LayerNorm mode the hidden states themselves are normalized before each sub-layer; in post-LayerNorm
    mode the sum of sub-layer output and residual is normalized. This matches the transformer block in the
    [original UniDiffuser
    implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. Required when `norm_type` is an adaptive norm.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether the first attention layer also attends to `encoder_hidden_states`.
        double_self_attention (`bool`, *optional*):
            Whether the second attention layer is built without a cross-attention dimension.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float() when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to normalize before the attention and feed-forward operations ("pre-LayerNorm"), as opposed to
            after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        pre_layer_norm: bool = False,
        final_dropout: bool = True,
    ):
        super().__init__()

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.pre_layer_norm = pre_layer_norm

        has_second_attention = cross_attention_dim is not None or double_self_attention

        # Settings shared by both attention layers.
        shared_attn_kwargs = {
            "query_dim": dim,
            "heads": num_attention_heads,
            "dim_head": attention_head_dim,
            "dropout": dropout,
            "bias": attention_bias,
            "upcast_attention": upcast_attention,
        }

        def build_attention_norm():
            # Timestep-conditioned AdaLayerNorm when enabled, plain LayerNorm otherwise.
            if self.use_ada_layer_norm:
                return AdaLayerNorm(dim, num_embeds_ada_norm)
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        # 1. Self-Attn
        self.attn1 = Attention(
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            **shared_attn_kwargs,
        )

        # 2. Cross-Attn (acts as self-attention if encoder_hidden_states is None)
        if has_second_attention:
            self.attn2 = Attention(
                cross_attention_dim=None if double_self_attention else cross_attention_dim,
                **shared_attn_kwargs,
            )
        else:
            self.attn2 = None

        self.norm1 = build_attention_norm()
        self.norm2 = build_attention_norm() if has_second_attention else None

        # 3. Feed-forward (always a plain LayerNorm)
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

    def _apply_norm(self, norm_layer, states, timestep):
        """Call an attention norm layer, passing `timestep` only to the adaptive variant."""
        if self.use_ada_layer_norm:
            return norm_layer(states, timestep)
        return norm_layer(states)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        """Apply self-attention, optional cross-attention and feed-forward with norms on the residual stream.

        `class_labels` is accepted for signature compatibility but is not used by this block.
        """
        pre_norm = self.pre_layer_norm
        post_norm = not pre_norm
        extra_attn_kwargs = {} if cross_attention_kwargs is None else cross_attention_kwargs

        # 1. Self-Attention. The norm acts on the residual stream itself, so in pre-LN mode the
        # normalized states are also what gets added back after attention.
        if pre_norm:
            hidden_states = self._apply_norm(self.norm1, hidden_states, timestep)
        attn_output = self.attn1(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **extra_attn_kwargs,
        )
        hidden_states = attn_output + hidden_states
        if post_norm:
            hidden_states = self._apply_norm(self.norm1, hidden_states, timestep)

        # 2. Cross-Attention
        if self.attn2 is not None:
            if pre_norm:
                hidden_states = self._apply_norm(self.norm2, hidden_states, timestep)
            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
            # prepare attention mask here
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **extra_attn_kwargs,
            )
            hidden_states = attn_output + hidden_states
            if post_norm:
                hidden_states = self._apply_norm(self.norm2, hidden_states, timestep)

        # 3. Feed-forward
        if pre_norm:
            hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(hidden_states)
        hidden_states = ff_output + hidden_states
        if post_norm:
            hidden_states = self.norm3(hidden_states)

        return hidden_states
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups to use when performing Group Normalization. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. patch_size (`int`, *optional*, defaults to 2): The patch size to use in the patch embedding. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. use_linear_projection (int, *optional*): TODO: Not used only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used in each transformer block. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float() when performing the attention calculation. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. block_type (`str`, *optional*, defaults to `"unidiffuser"`): The transformer block implementation to use. 
If `"unidiffuser"`, has the LayerNorms on the residual backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard behavior in `diffusers`.) pre_layer_norm (`bool`, *optional*): Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm (`pre_layer_norm = False`). norm_elementwise_affine (`bool`, *optional*): Whether to use learnable per-element affine parameters during layer normalization. use_patch_pos_embed (`bool`, *optional*): Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). final_dropout (`bool`, *optional*): Whether to use a final Dropout layer after the feedforward network. """ @register_to_config def __init__( self, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = 2, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", block_type: str = "unidiffuser", pre_layer_norm: bool = False, norm_elementwise_affine: bool = True, use_patch_pos_embed=False, ff_final_dropout: bool = False, ): super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim # 1. 
Input # Only support patch input of shape (batch_size, num_channels, height, width) for now assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" # 2. Define input layers self.height = sample_size self.width = sample_size self.patch_size = patch_size self.pos_embed = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, use_pos_embed=use_patch_pos_embed, ) # 3. Define transformers blocks # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). # Quick hack to make the transformer block type configurable if block_type == "unidiffuser": block_cls = UniDiffuserBlock else: block_cls = UTransformerBlock self.transformer_in_blocks = nn.ModuleList( [ block_cls( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, final_dropout=ff_final_dropout, ) for d in range(num_layers // 2) ] ) self.transformer_mid_block = block_cls( inner_dim, num_attention_heads, attention_head_dim, dropout=dropout, cross_attention_dim=cross_attention_dim, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, pre_layer_norm=pre_layer_norm, 
def forward(
    self,
    hidden_states,
    encoder_hidden_states=None,
    timestep=None,
    class_labels=None,
    cross_attention_kwargs=None,
    return_dict: bool = True,
    hidden_states_is_embedding: bool = False,
    unpatchify: bool = True,
):
    """
    Run the U-shaped stack of transformer blocks over `hidden_states`.

    Args:
        hidden_states:
            Model input. Unless `hidden_states_is_embedding` is set, it is first passed through
            `self.pos_embed`; otherwise it is fed to the transformer blocks as-is.
        encoder_hidden_states (*optional*):
            Forwarded unchanged to every in-block and out-block.
        timestep (*optional*):
            Forwarded unchanged to every in-block and out-block.
        class_labels (*optional*):
            Forwarded unchanged to every in-block and out-block.
        cross_attention_kwargs (*optional*):
            Forwarded unchanged to every in-block and out-block.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to wrap the result in a `Transformer2DModelOutput` instead of a plain tuple.
        hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
            Whether `hidden_states` is already a token embedding, in which case `self.pos_embed` is skipped.
        unpatchify (`bool`, *optional*, defaults to `True`):
            Whether to fold the token sequence back into an image-shaped tensor. The token count is assumed to
            be a perfect square and the hidden size is assumed to equal `patch_size**2 * out_channels`.

    Returns:
        `Transformer2DModelOutput` if `return_dict` is True, otherwise a 1-tuple holding the output tensor.

    Raises:
        ValueError: If `unpatchify` is False while `return_dict` is True.
    """
    # 0. Reject the unsupported flag combination up front.
    if return_dict and not unpatchify:
        raise ValueError(
            f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when"
            f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)"
            " rather than (batch_size, num_channels, height, width)."
        )

    # 1. Embed the input unless the caller already did so.
    if not hidden_states_is_embedding:
        hidden_states = self.pos_embed(hidden_states)

    # 2. Transformer blocks. The same conditioning is handed to every in/out block.
    conditioning = {
        "encoder_hidden_states": encoder_hidden_states,
        "timestep": timestep,
        "cross_attention_kwargs": cross_attention_kwargs,
        "class_labels": class_labels,
    }

    # In-blocks: remember each output so the matching out-block can consume it.
    skip_stack = []
    for in_block in self.transformer_in_blocks:
        hidden_states = in_block(hidden_states, **conditioning)
        skip_stack.append(hidden_states)

    # Mid block: called with the hidden states only (no conditioning arguments).
    hidden_states = self.transformer_mid_block(hidden_states)

    # Out-blocks: skips are popped last-in-first-out, i.e. the last in-block pairs with the first out-block.
    for stage in self.transformer_out_blocks:
        hidden_states = stage["skip"](hidden_states, skip_stack.pop())
        hidden_states = stage["block"](hidden_states, **conditioning)

    # 3. Output normalisation; no conditioning/scale/shift is applied here.
    hidden_states = self.norm_out(hidden_states)

    if unpatchify:
        patch = self.patch_size
        channels = self.out_channels
        # Square grid of patches: side length is the square root of the sequence length.
        side = int(hidden_states.shape[1] ** 0.5)
        grid = hidden_states.reshape(shape=(-1, side, side, patch, patch, channels))
        grid = torch.einsum("nhwpqc->nchpwq", grid)
        output = grid.reshape(shape=(-1, channels, side * patch, side * patch))
    else:
        output = hidden_states

    if return_dict:
        return Transformer2DModelOutput(sample=output)
    return (output,)
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): Pass if the input is continuous. The number of channels in the input. out_channels (`int`, *optional*): The number of output channels; if `None`, defaults to `in_channels`. num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups to use when performing Group Normalization. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. patch_size (`int`, *optional*, defaults to 2): The patch size to use in the patch embedding. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. 
use_linear_projection (int, *optional*): TODO: Not used only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used in each transformer block. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float32 when performing the attention calculation. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. block_type (`str`, *optional*, defaults to `"unidiffuser"`): The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard behavior in `diffusers`.) pre_layer_norm (`bool`, *optional*): Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm (`pre_layer_norm = False`). norm_elementwise_affine (`bool`, *optional*): Whether to use learnable per-element affine parameters during layer normalization. use_patch_pos_embed (`bool`, *optional*): Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). ff_final_dropout (`bool`, *optional*): Whether to use a final Dropout layer after the feedforward network. use_data_type_embedding (`bool`, *optional*): Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` argument, which can either be `1` to use the weights trained on non-publically-available data or `0` otherwise. This argument is subsequently embedded by the data type embedding, if used. 
""" @register_to_config def __init__( self, text_dim: int = 768, clip_img_dim: int = 512, num_text_tokens: int = 77, num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", block_type: str = "unidiffuser", pre_layer_norm: bool = False, use_timestep_embedding=False, norm_elementwise_affine: bool = True, use_patch_pos_embed=False, ff_final_dropout: bool = True, use_data_type_embedding: bool = False, ): super().__init__() # 0. Handle dimensions self.inner_dim = num_attention_heads * attention_head_dim assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" self.sample_size = sample_size self.in_channels = in_channels self.out_channels = in_channels if out_channels is None else out_channels self.patch_size = patch_size # Assume image is square... self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) # 1. Define input layers # 1.1 Input layers for text and image input # For now, only support patch input for VAE latent image input self.vae_img_in = PatchEmbed( height=sample_size, width=sample_size, patch_size=patch_size, in_channels=in_channels, embed_dim=self.inner_dim, use_pos_embed=use_patch_pos_embed, ) self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) self.text_in = nn.Linear(text_dim, self.inner_dim) # 1.2. 
Timestep embeddings for t_img, t_text self.timestep_img_proj = Timesteps( self.inner_dim, flip_sin_to_cos=True, downscale_freq_shift=0, ) self.timestep_img_embed = ( TimestepEmbedding( self.inner_dim, 4 * self.inner_dim, out_dim=self.inner_dim, ) if use_timestep_embedding else nn.Identity() ) self.timestep_text_proj = Timesteps( self.inner_dim, flip_sin_to_cos=True, downscale_freq_shift=0, ) self.timestep_text_embed = ( TimestepEmbedding( self.inner_dim, 4 * self.inner_dim, out_dim=self.inner_dim, ) if use_timestep_embedding else nn.Identity() ) # 1.3. Positional embedding self.num_text_tokens = num_text_tokens self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) self.pos_embed_drop = nn.Dropout(p=dropout) trunc_normal_(self.pos_embed, std=0.02) # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary self.use_data_type_embedding = use_data_type_embedding if self.use_data_type_embedding: self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) # 2. Define transformer blocks self.transformer = UTransformer2DModel( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, in_channels=in_channels, out_channels=out_channels, num_layers=num_layers, dropout=dropout, norm_num_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attention_bias=attention_bias, sample_size=sample_size, num_vector_embeds=num_vector_embeds, patch_size=patch_size, activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, norm_type=norm_type, block_type=block_type, pre_layer_norm=pre_layer_norm, norm_elementwise_affine=norm_elementwise_affine, use_patch_pos_embed=use_patch_pos_embed, ff_final_dropout=ff_final_dropout, ) # 3. 
Define output layers patch_dim = (patch_size**2) * out_channels self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) self.text_out = nn.Linear(self.inner_dim, text_dim) @torch.jit.ignore def no_weight_decay(self): return {"pos_embed"} def forward( self, latent_image_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor, prompt_embeds: torch.FloatTensor, timestep_img: Union[torch.Tensor, float, int], timestep_text: Union[torch.Tensor, float, int], data_type: Optional[Union[torch.Tensor, float, int]] = 1, encoder_hidden_states=None, cross_attention_kwargs=None, ): """ Args: latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`): Latent image representation from the VAE encoder. image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`): CLIP-embedded image representation (unsqueezed in the first dimension). prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`): CLIP-embedded text representation. timestep_img (`torch.long` or `float` or `int`): Current denoising step for the image. timestep_text (`torch.long` or `float` or `int`): Current denoising step for the text. data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, or `0` otherwise. encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. cross_attention_kwargs (*optional*): Keyword arguments to supply to the cross attention layers, if used. Returns: `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text embedding. 
""" batch_size = latent_image_embeds.shape[0] # 1. Input # 1.1. Map inputs to shape (B, N, inner_dim) vae_hidden_states = self.vae_img_in(latent_image_embeds) clip_hidden_states = self.clip_img_in(image_embeds) text_hidden_states = self.text_in(prompt_embeds) num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) # 1.2. Encode image timesteps to single token (B, 1, inner_dim) if not torch.is_tensor(timestep_img): timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) timestep_img_token = self.timestep_img_proj(timestep_img) # t_img_token does not contain any weights and will always return f32 tensors # but time_embedding might be fp16, so we need to cast here. timestep_img_token = timestep_img_token.to(dtype=self.dtype) timestep_img_token = self.timestep_img_embed(timestep_img_token) timestep_img_token = timestep_img_token.unsqueeze(dim=1) # 1.3. Encode text timesteps to single token (B, 1, inner_dim) if not torch.is_tensor(timestep_text): timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) timestep_text_token = self.timestep_text_proj(timestep_text) # t_text_token does not contain any weights and will always return f32 tensors # but time_embedding might be fp16, so we need to cast here. timestep_text_token = timestep_text_token.to(dtype=self.dtype) timestep_text_token = self.timestep_text_embed(timestep_text_token) timestep_text_token = timestep_text_token.unsqueeze(dim=1) # 1.4. Concatenate all of the embeddings together. 
if self.use_data_type_embedding: assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" if not torch.is_tensor(data_type): data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) hidden_states = torch.cat( [ timestep_img_token, timestep_text_token, data_type_token, text_hidden_states, clip_hidden_states, vae_hidden_states, ], dim=1, ) else: hidden_states = torch.cat( [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], dim=1, ) # 1.5. Prepare the positional embeddings and add to hidden states # Note: I think img_vae should always have the proper shape, so there's no need to interpolate # the position embeddings. if self.use_data_type_embedding: pos_embed = torch.cat( [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 ) else: pos_embed = self.pos_embed hidden_states = hidden_states + pos_embed hidden_states = self.pos_embed_drop(hidden_states) # 2. Blocks hidden_states = self.transformer( hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=None, class_labels=None, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, hidden_states_is_embedding=True, unpatchify=False, )[0] # 3. Output # Split out the predicted noise representation. 
def preprocess(image):
    """
    Deprecated helper that converts image input into a single `torch.Tensor`.

    A `FutureWarning` is always emitted. A tensor is returned untouched. A PIL image (or a list of them) is
    resized so both sides are multiples of 8 (size taken from the first image), stacked, divided by 255,
    reordered with `transpose(0, 3, 1, 2)` and mapped through `2 * x - 1`. A list of tensors is concatenated
    along dim 0. Any other input is returned as given.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        # Treat a lone PIL image as a batch of one.
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        width, height = image[0].size
        # Round each side down to an integer multiple of 8.
        width -= width % 8
        height -= height % 8
        frames = [
            np.array(frame.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for frame in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        image = torch.from_numpy(2.0 * batch - 1.0)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    return image
This is part of the UniDiffuser image representation, along with the CLIP vision encoding. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text prompts. image_encoder ([`CLIPVisionModel`]): UniDiffuser uses the vision portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode images as part of its image representation, along with the VAE latent representation. image_processor ([`CLIPImageProcessor`]): CLIP image processor of class [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor), used to preprocess the image before CLIP encoding it with `image_encoder`. clip_tokenizer ([`CLIPTokenizer`]): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which is used to tokenizer a prompt before encoding it with `text_encoder`. text_decoder ([`UniDiffuserTextDecoder`]): Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser embedding. text_tokenizer ([`GPT2Tokenizer`]): Tokenizer of class [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which is used along with the `text_decoder` to decode text for text generation. unet ([`UniDiffuserModel`]): UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a [`Transformer2DModel`] with U-Net-style skip connections between transformer layers. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. 
""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, image_encoder: CLIPVisionModelWithProjection, image_processor: CLIPImageProcessor, clip_tokenizer: CLIPTokenizer, text_decoder: UniDiffuserTextDecoder, text_tokenizer: GPT2Tokenizer, unet: UniDiffuserModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: raise ValueError( f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" ) self.register_modules( vae=vae, text_encoder=text_encoder, image_encoder=image_encoder, image_processor=image_processor, clip_tokenizer=clip_tokenizer, text_decoder=text_decoder, text_tokenizer=text_tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.num_channels_latents = vae.config.latent_channels self.text_encoder_seq_len = text_encoder.config.max_position_embeddings self.text_encoder_hidden_size = text_encoder.config.hidden_size self.image_encoder_projection_dim = image_encoder.config.projection_dim self.unet_resolution = unet.config.sample_size self.text_intermediate_dim = self.text_encoder_hidden_size if self.text_decoder.prefix_hidden_dim is not None: self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim self.mode = None # TODO: handle safety checking? self.safety_checker = None # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole sub-models to CPU with accelerate hooks, keeping one model at a time on the GPU.

    The text encoder, unet, VAE, image encoder and text decoder (plus the safety checker when one is set) are
    registered with `cpu_offload_with_hook`, each hook chained to the previous one. The last hook is stored in
    `self.final_offload_hook` so the final model can be offloaded manually.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device used as the execution device.

    Raises:
        ImportError: If accelerate is unavailable or older than `0.17.0.dev0`.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_ok:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        # Without this the memory savings are not visible (though they probably exist).
        torch.cuda.empty_cache()

    # Chain the hooks: each call receives the hook produced for the previous model.
    last_hook = None
    offload_order = (self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder)
    for submodel in offload_order:
        _, last_hook = cpu_offload_with_hook(submodel, device, prev_module_hook=last_hook)

    if self.safety_checker is not None:
        _, last_hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=last_hook)

    # The last model in the chain is offloaded manually via this hook.
    self.final_offload_hook = last_hook
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): r""" Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set mode will be used. """ prompt_available = (prompt is not None) or (prompt_embeds is not None) image_available = image is not None input_available = prompt_available or image_available prompt_latents_available = prompt_latents is not None vae_latents_available = vae_latents is not None clip_latents_available = clip_latents is not None full_latents_available = latents is not None image_latents_available = vae_latents_available and clip_latents_available all_indv_latents_available = prompt_latents_available and image_latents_available if self.mode is not None: # Preferentially use the mode set by the user mode = self.mode elif prompt_available: mode = "text2img" elif image_available: mode = "img2text" else: # Neither prompt nor image supplied, infer based on availability of latents if full_latents_available or all_indv_latents_available: mode = "joint" elif prompt_latents_available: mode = "text" elif image_latents_available: mode = "img" else: # No inputs or latents available mode = "joint" # Give warnings for ambiguous cases if self.mode is None and prompt_available and image_available: logger.warning( f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," f" defaulting to mode '{mode}'." 
def _infer_batch_size(
    self,
    mode,
    prompt,
    prompt_embeds,
    image,
    num_images_per_prompt,
    num_prompts_per_image,
    latents,
    prompt_latents,
    vae_latents,
    clip_latents,
):
    r"""
    Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.

    Args:
        mode (`str`): One of `"text2img"`, `"img2text"`, `"img"`, `"text"` or `"joint"`.
        prompt (`str` or `List[str]`, *optional*): Text prompt(s); used in `"text2img"` mode.
        prompt_embeds (*optional*): Used in `"text2img"` mode when `prompt` is neither a `str` nor a `list`;
            its first dimension is taken as the batch size.
        image (*optional*): Used in `"img2text"` mode. A single `PIL.Image.Image` gives batch size 1; anything
            else must expose `.shape[0]`.
        num_images_per_prompt (`int`, *optional*): Treated as 1 when `None`.
        num_prompts_per_image (`int`, *optional*): Treated as 1 when `None`.
        latents, prompt_latents, vae_latents, clip_latents (*optional*): Pre-supplied latents; the first one
            available (in a mode-specific order) provides the batch size through `.shape[0]`.

    Returns:
        `tuple`: `(batch_size, multiplier)`. The multiplier is `num_images_per_prompt` for `"text2img"` and
        `"img"`, `num_prompts_per_image` for `"img2text"` and `"text"`, and the smaller of the two for
        `"joint"`.

    Raises:
        AssertionError: If either count is not positive.
        ValueError: If `mode` is not one of the supported modes.
    """
    if num_images_per_prompt is None:
        num_images_per_prompt = 1
    if num_prompts_per_image is None:
        num_prompts_per_image = 1

    # Kept as asserts so callers observing AssertionError see no change in exception type.
    assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer"
    assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer"

    if mode in ["text2img"]:
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # Either prompt or prompt_embeds must be present for text2img.
            batch_size = prompt_embeds.shape[0]
        multiplier = num_images_per_prompt
    elif mode in ["img2text"]:
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        else:
            # Image must be available and type either PIL.Image.Image or torch.FloatTensor.
            # Not currently supporting something like image_embeds.
            batch_size = image.shape[0]
        multiplier = num_prompts_per_image
    elif mode in ["img"]:
        if vae_latents is not None:
            batch_size = vae_latents.shape[0]
        elif clip_latents is not None:
            batch_size = clip_latents.shape[0]
        else:
            batch_size = 1
        multiplier = num_images_per_prompt
    elif mode in ["text"]:
        if prompt_latents is not None:
            batch_size = prompt_latents.shape[0]
        else:
            batch_size = 1
        multiplier = num_prompts_per_image
    elif mode in ["joint"]:
        if latents is not None:
            batch_size = latents.shape[0]
        elif prompt_latents is not None:
            batch_size = prompt_latents.shape[0]
        elif vae_latents is not None:
            batch_size = vae_latents.shape[0]
        elif clip_latents is not None:
            batch_size = clip_latents.shape[0]
        else:
            batch_size = 1

        if num_images_per_prompt == num_prompts_per_image:
            multiplier = num_images_per_prompt
        else:
            multiplier = min(num_images_per_prompt, num_prompts_per_image)
            # Report the value actually chosen (the multiplier), not the unrelated batch size.
            logger.warning(
                f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and"
                f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using a multiplier equal to"
                f" `min(num_images_per_prompt, num_prompts_per_image)` = {multiplier}."
            )
    else:
        # Previously an unknown mode surfaced as an UnboundLocalError at the return statement.
        raise ValueError(
            f"Unknown mode `{mode}`; expected one of 'text2img', 'img2text', 'img', 'text' or 'joint'."
        )

    return batch_size, multiplier
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.clip_tokenizer( prompt, padding="max_length", max_length=self.clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.clip_tokenizer.batch_decode( untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.clip_tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
# Add num_prompts_per_image argument, sample from autoencoder moment distribution
def encode_image_vae_latents(
    self,
    image,
    batch_size,
    num_prompts_per_image,
    dtype,
    device,
    do_classifier_free_guidance,
    generator=None,
):
    r"""
    Encode `image` with `self.vae` and sample latents scaled by `self.vae.config.scaling_factor`.

    The target batch size is `batch_size * num_prompts_per_image`. If fewer latents than that were produced and the
    target is an exact multiple, the latents are tiled (with a deprecation notice); a non-multiple raises. With
    `do_classifier_free_guidance`, the result is `[latents, latents, zeros_like(latents)]` concatenated on dim 0.

    Raises:
        `ValueError`: on an unsupported `image` type, a generator list whose length differs from the target batch
        size, or a latent batch that cannot be tiled evenly.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    # NOTE(review): `.to` is called unconditionally, so `image` presumably already is a tensor here even though
    # the type check above also admits PIL images and lists -- confirm against the caller.
    image = image.to(device=device, dtype=dtype)
    batch_size = batch_size * num_prompts_per_image

    uses_generator_list = isinstance(generator, list)
    if uses_generator_list and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if uses_generator_list:
        # One encode/sample per generator; each sample is scaled as it is drawn.
        per_sample = []
        for idx in range(batch_size):
            posterior = self.vae.encode(image[idx : idx + 1]).latent_dist
            per_sample.append(posterior.sample(generator=generator[idx]) * self.vae.config.scaling_factor)
        image_latents = torch.cat(per_sample, dim=0)
    else:
        image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
        # Scale image_latents by the VAE's scaling factor
        image_latents = image_latents * self.vae.config.scaling_factor

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if do_classifier_free_guidance:
        zero_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat([image_latents, image_latents, zero_latents], dim=0)

    return image_latents
# Note that the CLIP latents are not decoded for image generation.
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Rename: decode_latents -> decode_image_latents
def decode_image_latents(self, latents):
    r"""
    Decode VAE latents into an image batch.

    `latents` is divided by `self.vae.config.scaling_factor`, decoded with `self.vae.decode`, mapped through
    `x / 2 + 0.5`, clamped to [0, 1], and returned as a float32 NumPy array on the CPU with axes permuted as
    (0, 2, 3, 1) (channels-last, assuming a (B, C, H, W) decoder output).
    """
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image

def prepare_text_latents(
    self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
):
    r"""
    Prepare latents for the CLIP embedded prompt.

    Samples noise of shape `(batch_size * num_images_per_prompt, seq_len, hidden_size)` when `latents` is `None`;
    otherwise repeats the supplied `latents` `num_images_per_prompt` times along dim 0 and moves them to
    `device`/`dtype`. The result is multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        `ValueError`: if `generator` is a list whose length differs from the effective batch size
        `batch_size * num_images_per_prompt`.
    """
    shape = (batch_size * num_images_per_prompt, seq_len, hidden_size)
    effective_batch_size = shape[0]
    # A generator list supplies one generator per sampled row, so it has to match the effective batch size
    # (the first dim of `shape`), not the un-multiplied `batch_size`.
    if isinstance(generator, list) and len(generator) != effective_batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, L, D)
        latents = latents.repeat(num_images_per_prompt, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
# Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
def prepare_image_vae_latents(
    self,
    batch_size,
    num_prompts_per_image,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    r"""
    Prepare latents for the VAE-encoded image.

    Samples noise of shape `(batch_size * num_prompts_per_image, num_channels_latents,
    height // self.vae_scale_factor, width // self.vae_scale_factor)` when `latents` is `None`; otherwise repeats
    the supplied `latents` `num_prompts_per_image` times along dim 0 and moves them to `device`/`dtype`. The
    result is multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        `ValueError`: if `generator` is a list whose length differs from the effective batch size
        `batch_size * num_prompts_per_image`.
    """
    shape = (
        batch_size * num_prompts_per_image,
        num_channels_latents,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    effective_batch_size = shape[0]
    # Validate against the number of rows that will actually be sampled.
    if isinstance(generator, list) and len(generator) != effective_batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, C, H, W)
        latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

def prepare_image_clip_latents(
    self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None
):
    r"""
    Prepare latents for the CLIP embedded image.

    Samples noise of shape `(batch_size * num_prompts_per_image, 1, clip_img_dim)` when `latents` is `None`;
    otherwise repeats the supplied `latents` `num_prompts_per_image` times along dim 0 and moves them to
    `device`/`dtype`. The result is multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        `ValueError`: if `generator` is a list whose length differs from the effective batch size
        `batch_size * num_prompts_per_image`.
    """
    shape = (batch_size * num_prompts_per_image, 1, clip_img_dim)
    effective_batch_size = shape[0]
    # Validate against the number of rows that will actually be sampled.
    if isinstance(generator, list) and len(generator) != effective_batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, L, D)
        latents = latents.repeat(num_prompts_per_image, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

def _split(self, x, height, width):
    r"""
    Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
    and (B, 1, clip_img_dim), where C is `self.num_channels_latents`, H and W are `height` and `width` divided by
    `self.vae_scale_factor`, and clip_img_dim is `self.image_encoder_projection_dim`.
    """
    batch_size = x.shape[0]
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    img_vae_dim = self.num_channels_latents * latent_height * latent_width

    img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1)

    img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
    img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
    return img_vae, img_clip
def _split_joint(self, x, height, width):
    r"""
    Break a flattened joint embedding `x` of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim) into
    its three parts.

    Returns `(img_vae, img_clip, text)` with shapes (B, C, H, W), (B, 1, clip_img_dim) and
    (B, text_seq_len, text_dim). C is `self.num_channels_latents`, H and W are `height` and `width` divided by
    `self.vae_scale_factor`, clip_img_dim is `self.image_encoder_projection_dim`, text_seq_len is
    `self.text_encoder_seq_len` and text_dim is `self.text_intermediate_dim`.
    """
    n = x.shape[0]
    h = height // self.vae_scale_factor
    w = width // self.vae_scale_factor

    vae_width = self.num_channels_latents * h * w
    clip_width = self.image_encoder_projection_dim
    text_width = self.text_encoder_seq_len * self.text_intermediate_dim

    vae_flat, clip_flat, text_flat = torch.split(x, [vae_width, clip_width, text_width], dim=1)

    return (
        vae_flat.reshape(n, self.num_channels_latents, h, w),
        clip_flat.reshape(n, 1, self.image_encoder_projection_dim),
        text_flat.reshape(n, self.text_encoder_seq_len, self.text_intermediate_dim),
    )

def _combine_joint(self, img_vae, img_clip, text):
    r"""
    Flatten and concatenate the three parts of a joint embedding.

    Takes a latent image `img_vae` of shape (B, C, H, W), a CLIP-embedded image `img_clip` of shape
    (B, L_img, clip_img_dim) and a text embedding `text` of shape (B, L_text, text_dim), and returns one tensor of
    shape (B, C * H * W + L_img * clip_img_dim + L_text * text_dim).
    """
    flattened = [part.reshape(part.shape[0], -1) for part in (img_vae, img_clip, text)]
    return torch.cat(flattened, dim=-1)
def check_latents_shape(self, latents_name, latents, expected_shape):
    r"""
    Verify that `latents.shape` equals `(batch_size, *expected_shape)` for any batch size.

    Raises:
        `ValueError`: if the number of dimensions differs from `len(expected_shape) + 1`, or if any non-batch
        dimension differs from the corresponding entry of `expected_shape`.
    """
    actual_shape = latents.shape
    wanted_rank = len(expected_shape) + 1  # expected dimensions plus the batch dimension
    wanted_str = ", ".join(map(str, expected_shape))
    preamble = (
        f"`{latents_name}` should have shape (batch_size, {wanted_str}), but the current shape {actual_shape}"
    )

    if len(actual_shape) != wanted_rank:
        raise ValueError(f"{preamble} has {len(actual_shape)} dimensions.")

    # Dimension 0 is the free batch dimension, so comparison starts at dimension 1.
    for axis, (got, want) in enumerate(zip(actual_shape[1:], expected_shape), start=1):
        if got != want:
            raise ValueError(f"{preamble} has {got} != {want} at dimension {axis}.")

def check_inputs(
    self,
    mode,
    prompt,
    image,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    latents=None,
    prompt_latents=None,
    vae_latents=None,
    clip_latents=None,
):
    r"""
    Validate the arguments of a pipeline call before running the generative process.

    Checks, in order: `height`/`width` divisibility by `self.vae_scale_factor`; `callback_steps` being a positive
    `int`; the prompt arguments for `text2img`; the presence of `image` for `img2text`; the shapes of any
    supplied latents; and matching batch dimensions between individually supplied latents.

    Raises:
        `ValueError`: on the first check that fails.
    """
    # Check inputs before running the generative process.
    if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
        )

    # `None` is not an `int`, so the isinstance test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if mode == "text2img":
        has_prompt = prompt is not None
        has_prompt_embeds = prompt_embeds is not None

        if has_prompt and has_prompt_embeds:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if not has_prompt and not has_prompt_embeds:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if has_prompt and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # NOTE(review): the source text had lost its indentation; the two negative-prompt checks below are
        # reconstructed as part of the text2img branch -- confirm against the original file.
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if (
            has_prompt_embeds
            and negative_prompt_embeds is not None
            and prompt_embeds.shape != negative_prompt_embeds.shape
        ):
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if mode == "img2text" and image is None:
        raise ValueError("`img2text` mode requires an image to be provided.")

    # Check provided latents
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor

    has_full = latents is not None
    has_prompt_latents = prompt_latents is not None
    has_vae_latents = vae_latents is not None
    has_clip_latents = clip_latents is not None

    if has_full:
        if has_prompt_latents or has_vae_latents or has_clip_latents:
            logger.warning(
                "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and"
                " `clip_latents`. The value of `latents` will override the value of any individually supplied latents."
            )
        # Check shape of full latents
        img_vae_dim = self.num_channels_latents * latent_height * latent_width
        text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
        latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim
        self.check_latents_shape("latents", latents, (latents_dim,))

    # Check individual latent shapes, if present
    if has_prompt_latents:
        self.check_latents_shape(
            "prompt_latents", prompt_latents, (self.text_encoder_seq_len, self.text_encoder_hidden_size)
        )

    if has_vae_latents:
        self.check_latents_shape(
            "vae_latents", vae_latents, (self.num_channels_latents, latent_height, latent_width)
        )

    if has_clip_latents:
        self.check_latents_shape("clip_latents", clip_latents, (1, self.image_encoder_projection_dim))

    if mode in ["text2img", "img"] and has_vae_latents and has_clip_latents:
        if vae_latents.shape[0] != clip_latents.shape[0]:
            raise ValueError(
                f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:"
                f" {vae_latents.shape[0]} != {clip_latents.shape[0]}."
            )

    if mode == "joint" and has_prompt_latents and has_vae_latents and has_clip_latents:
        if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]:
            raise ValueError(
                f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch"
                f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}"
                f" != {clip_latents.shape[0]}."
            )
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 8.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. Note that the original [UniDiffuser paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`, which satisfies `w = w' + 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). Used in text-conditioned image generation (`text2img`) mode. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. num_prompts_per_image (`int`, *optional*, defaults to 1): The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint image-text generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note that this is assumed to be a full set of VAE, CLIP, and text latents, if supplied, this will override the value of `prompt_latents`, `vae_latents`, and `clip_latents`. prompt_latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. vae_latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. clip_latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
Used in text-conditioned image generation (`text2img`) mode. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. Used in text-conditioned image generation (`text2img`) mode. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the second element is a list of generated texts. """ # 0. Default height and width to unet height = height or self.unet_resolution * self.vae_scale_factor width = width or self.unet_resolution * self.vae_scale_factor # 1. Check inputs # Recalculate mode for each call to the pipeline. mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) self.check_inputs( mode, prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, latents, prompt_latents, vae_latents, clip_latents, ) # 2. 
Define call parameters batch_size, multiplier = self._infer_batch_size( mode, prompt, prompt_embeds, image, num_images_per_prompt, num_prompts_per_image, latents, prompt_latents, vae_latents, clip_latents, ) device = self._execution_device reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. # Note that this differs from the formulation in the unidiffusers paper! # do_classifier_free_guidance = guidance_scale > 1.0 # check if scheduler is in sigmas space # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") # 3. Encode input prompt, if available; otherwise prepare text latents if latents is not None: # Overwrite individual latents vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) if mode in ["text2img"]: # 3.1. Encode input prompt, if available assert prompt is not None or prompt_embeds is not None prompt_embeds = self._encode_prompt( prompt=prompt, device=device, num_images_per_prompt=multiplier, do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) else: # 3.2. Prepare text latent variables, if input not available prompt_embeds = self.prepare_text_latents( batch_size=batch_size, num_images_per_prompt=multiplier, seq_len=self.text_encoder_seq_len, hidden_size=self.text_encoder_hidden_size, dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision device=device, generator=generator, latents=prompt_latents, ) if reduce_text_emb_dim: prompt_embeds = self.text_decoder.encode(prompt_embeds) # 4. 
Encode image, if available; otherwise prepare image latents if mode in ["img2text"]: # 4.1. Encode images, if available assert image is not None, "`img2text` requires a conditioning image" # Encode image using VAE image_vae = preprocess(image) height, width = image_vae.shape[-2:] image_vae_latents = self.encode_image_vae_latents( image=image_vae, batch_size=batch_size, num_prompts_per_image=multiplier, dtype=prompt_embeds.dtype, device=device, do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG generator=generator, ) # Encode image using CLIP image_clip_latents = self.encode_image_clip_latents( image=image, batch_size=batch_size, num_prompts_per_image=multiplier, dtype=prompt_embeds.dtype, device=device, generator=generator, ) # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) image_clip_latents = image_clip_latents.unsqueeze(1) else: # 4.2. Prepare image latent variables, if input not available # Prepare image VAE latents in latent space image_vae_latents = self.prepare_image_vae_latents( batch_size=batch_size, num_prompts_per_image=multiplier, num_channels_latents=self.num_channels_latents, height=height, width=width, dtype=prompt_embeds.dtype, device=device, generator=generator, latents=vae_latents, ) # Prepare image CLIP latents image_clip_latents = self.prepare_image_clip_latents( batch_size=batch_size, num_prompts_per_image=multiplier, clip_img_dim=self.image_encoder_projection_dim, dtype=prompt_embeds.dtype, device=device, generator=generator, latents=clip_latents, ) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # max_timestep = timesteps[0] max_timestep = self.scheduler.config.num_train_timesteps # 6. 
Prepare latent variables if mode == "joint": latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) elif mode in ["text2img", "img"]: latents = self._combine(image_vae_latents, image_clip_latents) elif mode in ["img2text", "text"]: latents = prompt_embeds # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # predict the noise residual # Also applies classifier-free guidance as described in the UniDiffuser paper noise_pred = self._get_noise_pred( mode, latents, t, prompt_embeds, image_vae_latents, image_clip_latents, max_timestep, data_type, guidance_scale, generator, device, height, width, ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. 
Post-processing gen_image = None gen_text = None if mode == "joint": image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) # Map latent VAE image back to pixel space gen_image = self.decode_image_latents(image_vae_latents) # Generate text using the text decoder output_token_list, seq_lengths = self.text_decoder.generate_captions( text_latents, self.text_tokenizer.eos_token_id, device=device ) output_list = output_token_list.cpu().numpy() gen_text = [ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) for output, length in zip(output_list, seq_lengths) ] elif mode in ["text2img", "img"]: image_vae_latents, image_clip_latents = self._split(latents, height, width) gen_image = self.decode_image_latents(image_vae_latents) elif mode in ["img2text", "text"]: text_latents = latents output_token_list, seq_lengths = self.text_decoder.generate_captions( text_latents, self.text_tokenizer.eos_token_id, device=device ) output_list = output_token_list.cpu().numpy() gen_text = [ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) for output, length in zip(output_list, seq_lengths) ] # 10. 
Convert to PIL if output_type == "pil" and gen_image is not None: gen_image = self.numpy_to_pil(gen_image) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (gen_image, gen_text) return ImageTextPipelineOutput(images=gen_image, text=gen_text) ================================================ FILE: 4DoF/diffusers/pipelines/versatile_diffusion/__init__.py ================================================ from ...utils import ( OptionalDependencyNotAvailable, is_torch_available, is_transformers_available, is_transformers_version, ) try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, ) else: from .modeling_text_unet import UNetFlatConditionModel from .pipeline_versatile_diffusion import VersatileDiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline ================================================ FILE: 4DoF/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py ================================================ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation from ...models.attention import Attention from 
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    transformer_layers_per_block=1,
    attention_head_dim=None,
):
    """Build the "flat" down-sampling block named by `down_block_type`.

    Args:
        down_block_type: Either `"DownBlockFlat"` or `"CrossAttnDownBlockFlat"`.
            A leading `"UNetRes"` prefix is stripped before matching.
        num_layers, in_channels, out_channels, temb_channels, add_downsample,
        resnet_eps, resnet_act_fn, resnet_groups, downsample_padding,
        resnet_time_scale_shift: Forwarded unchanged to the block constructor.
        num_attention_heads, cross_attention_dim, dual_cross_attention,
        use_linear_projection, only_cross_attention: Forwarded only to
            `CrossAttnDownBlockFlat`.
        upcast_attention, resnet_skip_time_act, resnet_out_scale_factor,
        cross_attention_norm: Accepted for signature parity; not forwarded to
            either flat block.
        transformer_layers_per_block, attention_head_dim: Accepted because
            `UNetFlatConditionModel.__init__` passes them as keywords; without
            these parameters that call fails with `TypeError`.
            NOTE(review): they are not forwarded, since it is not visible from
            this function whether the flat block constructors accept them.

    Returns:
        The constructed `DownBlockFlat` or `CrossAttnDownBlockFlat` instance.

    Raises:
        ValueError: If `down_block_type` is not supported, or if
            `cross_attention_dim` is `None` for `CrossAttnDownBlockFlat`.
    """
    # Strip the 7-character "UNetRes" prefix so both naming schemes are accepted.
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type

    if down_block_type == "DownBlockFlat":
        return DownBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    if down_block_type == "CrossAttnDownBlockFlat":
        # Validate before touching the block class so the error is explicit.
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
        return CrossAttnDownBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    raise ValueError(f"{down_block_type} is not supported.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    transformer_layers_per_block=1,
    attention_head_dim=None,
):
    """Build the "flat" up-sampling block named by `up_block_type`.

    Args:
        up_block_type: Either `"UpBlockFlat"` or `"CrossAttnUpBlockFlat"`.
            A leading `"UNetRes"` prefix is stripped before matching.
        num_layers, in_channels, out_channels, prev_output_channel,
        temb_channels, add_upsample, resnet_eps, resnet_act_fn, resnet_groups,
        resnet_time_scale_shift: Forwarded unchanged to the block constructor.
        num_attention_heads, cross_attention_dim, dual_cross_attention,
        use_linear_projection, only_cross_attention: Forwarded only to
            `CrossAttnUpBlockFlat`.
        upcast_attention, resnet_skip_time_act, resnet_out_scale_factor,
        cross_attention_norm: Accepted for signature parity; not forwarded to
            either flat block.
        transformer_layers_per_block, attention_head_dim: Accepted because
            `UNetFlatConditionModel.__init__` passes
            `transformer_layers_per_block` as a keyword; without the parameter
            that call fails with `TypeError`. `attention_head_dim` mirrors
            `get_down_block`.
            NOTE(review): they are not forwarded, since it is not visible from
            this function whether the flat block constructors accept them.

    Returns:
        The constructed `UpBlockFlat` or `CrossAttnUpBlockFlat` instance.

    Raises:
        ValueError: If `up_block_type` is not supported, or if
            `cross_attention_dim` is `None` for `CrossAttnUpBlockFlat`.
    """
    # Strip the 7-character "UNetRes" prefix so both naming schemes are accepted.
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type

    if up_block_type == "UpBlockFlat":
        return UpBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    if up_block_type == "CrossAttnUpBlockFlat":
        # Validate before touching the block class so the error is explicit.
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
        return CrossAttnUpBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )

    raise ValueError(f"{up_block_type} is not supported.")
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. 
If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat", ), mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", up_block_types: Tuple[str] = ( "UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 
dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads`" " because of a naming issue as described in" " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing" " `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. 
num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:" f" {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:" f" {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( "Must provide the same number of `only_cross_attention` as `down_block_types`." f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:" f" {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" f" {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:" f" {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:" f" {layers_per_block}. `down_block_types`: {down_block_types}." 
) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = LinearMultiDim( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = 
get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlockFlatCrossAttn": self.mid_block = UNetMidBlockFlatCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, 
upcast_attention=upcast_attention, ) elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": self.mid_block = UNetMidBlockFlatSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, 
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict is only
            read; it is left unmodified so the caller can keep using it after this call.

    Raises:
        ValueError: If `processor` is a dict whose length differs from the number of attention processors reported
            by `self.attn_processors`.
        KeyError: If `processor` is a dict that lacks the `"<module path>.processor"` key of a visited module.
    """
    count = len(self.attn_processors.keys())
    per_layer = isinstance(processor, dict)

    if per_layer and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        # Every module exposing `set_processor` receives either the shared processor or its own dict entry.
        if hasattr(module, "set_processor"):
            if per_layer:
                # Look the entry up instead of popping it: popping would empty the caller's dictionary.
                module.set_processor(processor[f"{name}.processor"])
            else:
                module.set_processor(processor)

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
def _set_gradient_checkpointing(self, module, value=False):
    """Set `module.gradient_checkpointing` to `value` when `module` is one of the flat down/up blocks.

    Modules of any other type are left untouched.
    """
    checkpointable_blocks = (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)
    if not isinstance(module, checkpointable_blocks):
        return
    module.gradient_checkpointing = value
Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires" " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires" " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires" " the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], 
-1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the" " keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires" " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which" " requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" 
not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires" " the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) class LinearMultiDim(nn.Linear): def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) if out_features is None: out_features = in_features out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) self.in_features_multidim = in_features self.out_features_multidim = out_features super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) def forward(self, input_tensor, *args, **kwargs): shape = input_tensor.shape n_dim = len(self.in_features_multidim) input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) output_tensor = super().forward(input_tensor) output_tensor = output_tensor.view(*shape[0:-n_dim], 
class ResnetBlockFlat(nn.Module):
    """Residual block that treats the trailing feature dimensions of its input as one flat channel axis.

    The trailing dimensions (by default `[channels, second_dim, 1]`) are flattened into a single channel axis of
    size `prod(dims)` with a 1x1 spatial extent, so the 1x1 convolutions act as per-sample linear layers. The block
    applies GroupNorm -> SiLU -> conv, adds the projected time embedding, applies GroupNorm -> SiLU -> dropout ->
    conv, adds the (optionally projected) input and restores the multi-dimensional layout.

    Args:
        in_channels: Input feature size. An `int` is expanded to `[in_channels, second_dim, 1]`; any other value is
            taken as the full list of trailing dimensions.
        out_channels: Output feature size, expanded the same way. `None` keeps the input layout.
        dropout: Dropout probability applied before the second convolution.
        temb_channels: Size of the time embedding. `None` builds the block without a time projection.
        groups: Number of groups of the first GroupNorm.
        groups_out: Number of groups of the second GroupNorm; defaults to `groups`.
        pre_norm: Accepted for signature compatibility only; `self.pre_norm` is always `True`.
        eps: Epsilon of both GroupNorm layers.
        time_embedding_norm: Stored on the module; `forward` does not read it.
        use_in_shortcut: Forces (or disables) the 1x1 shortcut convolution. `None` enables it exactly when the
            flattened input and output sizes differ.
        second_dim: Second dimension used when expanding integer channel counts.
        **kwargs: Ignored; lets callers pass the keyword arguments of the 2D resnet block.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        time_embedding_norm="default",
        use_in_shortcut=None,
        second_dim=4,
        **kwargs,
    ):
        super().__init__()
        # The block is always pre-norm. The previous code stored `pre_norm` and then immediately overwrote it with
        # `True`; the dead store is dropped and the resulting attribute value is unchanged.
        self.pre_norm = True

        in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
        # Keep the flattened sizes as plain Python ints rather than NumPy scalars.
        self.in_channels_prod = int(np.prod(in_channels))
        self.channels_multidim = in_channels

        if out_channels is not None:
            out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
            out_channels_prod = int(np.prod(out_channels))
            self.out_channels_multidim = out_channels
        else:
            out_channels_prod = self.in_channels_prod
            self.out_channels_multidim = self.channels_multidim
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)

        self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = (
            self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
        )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
            )

    def forward(self, input_tensor, temb):
        """Apply the residual block.

        Args:
            input_tensor: Tensor whose trailing `len(self.channels_multidim)` dimensions hold the features.
            temb: Time embedding, or `None`. It is ignored when the block was built with `temb_channels=None`.

        Returns:
            Tensor with the leading dimensions of `input_tensor` followed by `self.out_channels_multidim`.
        """
        shape = input_tensor.shape
        n_dim = len(self.channels_multidim)
        leading_shape = shape[0:-n_dim]

        # The first reshape checks that the trailing dims really multiply to `in_channels_prod`; the second folds
        # all leading dims into one batch axis for the 2D layers.
        input_tensor = input_tensor.reshape(*leading_shape, self.in_channels_prod, 1, 1)
        input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        # Without a time projection there is nothing to map `temb` onto the channel axis, so it is skipped
        # instead of failing on a call to `None`.
        if temb is not None and self.time_emb_proj is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        # Restore the leading dimensions and the multi-dimensional feature layout in one step.
        output_tensor = output_tensor.view(*leading_shape, *self.out_channels_multidim)

        return output_tensor
# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
    """Down block that pairs each `ResnetBlockFlat` with a transformer attention layer.

    Each of the `num_layers` layers is a `ResnetBlockFlat` followed by a `Transformer2DModel` (or a
    `DualTransformer2DModel` when `dual_cross_attention` is set). When `add_downsample` is true a single
    `LinearMultiDim` is appended as the downsampler.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        head_dim = out_channels // num_attention_heads
        resnet_layers = []
        attention_layers = []

        # Resnet and attention of one layer are created back to back so the construction order stays the same.
        for layer_idx in range(num_layers):
            # Only the first layer sees the block's input width; later layers consume `out_channels`.
            layer_in_channels = out_channels if layer_idx > 0 else in_channels
            resnet_layers.append(
                ResnetBlockFlat(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                attention = DualTransformer2DModel(
                    num_attention_heads,
                    head_dim,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention = Transformer2DModel(
                    num_attention_heads,
                    head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            attention_layers.append(attention)

        # Registration order (attentions, resnets, downsamplers) is kept as-is.
        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = LinearMultiDim(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run every resnet/attention pair and the optional downsampler.

        Returns:
            A tuple `(hidden_states, output_states)`. `output_states` is a tuple holding the hidden states after
            each resnet/attention pair, plus the downsampled hidden states when downsamplers exist.
        """

        def wrap_for_checkpoint(module, return_dict=None):
            # `checkpoint` forwards positional inputs only, so `return_dict` is bound through the closure.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        use_checkpointing = self.training and self.gradient_checkpointing
        if use_checkpointing:
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        collected_states = []
        for resnet, attn in zip(self.resnets, self.attentions):
            if use_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                attn_outputs = torch.utils.checkpoint.checkpoint(
                    wrap_for_checkpoint(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)
                attn_outputs = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )
            hidden_states = attn_outputs[0]
            collected_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            collected_states.append(hidden_states)

        return hidden_states, tuple(collected_states)
add_upsample: self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states # Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim class CrossAttnUpBlockFlat(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = 
num_attention_heads for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlockFlat( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = 
res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, None, # timestep None, # class_labels cross_attention_kwargs, attention_mask, encoder_attention_mask, **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, ): super().__init__() self.has_cross_attention = True self.num_attention_heads = 
num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlockFlat( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] for _ in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) resnets.append( ResnetBlockFlat( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, 
encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) return hidden_states # Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat class UNetMidBlockFlatSimpleCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, cross_attention_dim=1280, skip_time_act=False, only_cross_attention=False, cross_attention_norm=None, ): super().__init__() self.has_cross_attention = True self.attention_head_dim = attention_head_dim resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.num_heads = in_channels // self.attention_head_dim # there is always at least one resnet resnets = [ ResnetBlockFlat( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ] attentions = [] for _ in range(num_layers): processor = ( AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() ) attentions.append( Attention( query_dim=in_channels, cross_attention_dim=in_channels, heads=self.num_heads, dim_head=self.attention_head_dim, added_kv_proj_dim=cross_attention_dim, norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, 
processor=processor, ) ) resnets.append( ResnetBlockFlat( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
# then we can simplify this whole if/else block to: # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) # resnet hidden_states = resnet(hidden_states, temb) return hidden_states ================================================ FILE: 4DoF/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py ================================================ import inspect from typing import Callable, List, Optional, Union import PIL.Image import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging from ..pipeline_utils import DiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. 
Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionMegaSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" tokenizer: CLIPTokenizer image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModel image_encoder: CLIPVisionModel image_unet: UNet2DConditionModel text_unet: UNet2DConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers def __init__( self, tokenizer: CLIPTokenizer, image_feature_extractor: CLIPImageProcessor, text_encoder: CLIPTextModel, image_encoder: CLIPVisionModel, image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( tokenizer=tokenizer, image_feature_extractor=image_feature_extractor, text_encoder=text_encoder, image_encoder=image_encoder, image_unet=image_unet, text_unet=text_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @torch.no_grad() def image_variation( self, image: Union[torch.FloatTensor, PIL.Image.Image], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): The image prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> pipe = VersatileDiffusionPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> image = pipe.image_variation(image, generator=generator).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
""" expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() components = {name: component for name, component in self.components.items() if name in expected_components} return VersatileDiffusionImageVariationPipeline(**components)( image=image, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, eta=eta, generator=generator, latents=latents, output_type=output_type, return_dict=return_dict, callback=callback, callback_steps=callback_steps, ) @torch.no_grad() def text_to_image( self, prompt: Union[str, List[str]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionPipeline >>> import torch >>> pipe = VersatileDiffusionPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] >>> image.save("./astronaut.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
""" expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() components = {name: component for name, component in self.components.items() if name in expected_components} temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) output = temp_pipeline( prompt=prompt, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, eta=eta, generator=generator, latents=latents, output_type=output_type, return_dict=return_dict, callback=callback, callback_steps=callback_steps, ) # swap the attention blocks back to the original state temp_pipeline._swap_unet_attention_blocks() return output @torch.no_grad() def dual_guided( self, prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], image: Union[str, List[str]], text_to_image_strength: float = 0.5, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> text = "a red car in the sun" >>> pipe = VersatileDiffusionPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> text_to_image_strength = 0.75 >>> image = pipe.dual_guided( ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator ... ).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
""" expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() components = {name: component for name, component in self.components.items() if name in expected_components} temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) output = temp_pipeline( prompt=prompt, image=image, text_to_image_strength=text_to_image_strength, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, eta=eta, generator=generator, latents=latents, output_type=output_type, return_dict=return_dict, callback=callback, callback_steps=callback_steps, ) temp_pipeline._revert_dual_attention() return output ================================================ FILE: 4DoF/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from typing import Callable, List, Optional, Tuple, Union import numpy as np import PIL import torch import torch.utils.checkpoint from transformers import ( CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" tokenizer: CLIPTokenizer image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModelWithProjection image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] def __init__( self, tokenizer: CLIPTokenizer, image_feature_extractor: CLIPImageProcessor, text_encoder: CLIPTextModelWithProjection, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( tokenizer=tokenizer, image_feature_extractor=image_feature_extractor, text_encoder=text_encoder, image_encoder=image_encoder, image_unet=image_unet, text_unet=text_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None and ( "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention ): # if loading from a universal checkpoint rather than a saved dual-guided pipeline self._convert_to_dual_attention() def remove_unused_weights(self): self.register_modules(text_unet=None) def _convert_to_dual_attention(self): """ Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks from both `image_unet` and `text_unet` """ for name, module in self.image_unet.named_modules(): if isinstance(module, Transformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) image_transformer = self.image_unet.get_submodule(parent_name)[index] text_transformer = self.text_unet.get_submodule(parent_name)[index] config = image_transformer.config dual_transformer = DualTransformer2DModel( num_attention_heads=config.num_attention_heads, 
attention_head_dim=config.attention_head_dim, in_channels=config.in_channels, num_layers=config.num_layers, dropout=config.dropout, norm_num_groups=config.norm_num_groups, cross_attention_dim=config.cross_attention_dim, attention_bias=config.attention_bias, sample_size=config.sample_size, num_vector_embeds=config.num_vector_embeds, activation_fn=config.activation_fn, num_embeds_ada_norm=config.num_embeds_ada_norm, ) dual_transformer.transformers[0] = image_transformer dual_transformer.transformers[1] = text_transformer self.image_unet.get_submodule(parent_name)[index] = dual_transformer self.image_unet.register_to_config(dual_cross_attention=True) def _revert_dual_attention(self): """ Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` """ for name, module in self.image_unet.named_modules(): if isinstance(module, DualTransformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] self.image_unet.register_to_config(dual_cross_attention=False) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not """ def normalize_embeddings(encoder_output): embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) embeds_pooled = encoder_output.text_embeds embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) return embeds batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens = [""] * batch_size max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, 
padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not """ def normalize_embeddings(encoder_output): embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) embeds = self.image_encoder.visual_projection(embeds) embeds_pooled = embeds[:, 0:1] embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) image_embeddings = self.image_encoder(pixel_values) image_embeddings = normalize_embeddings(image_embeddings) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) negative_prompt_embeds = self.image_encoder(pixel_values) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For 
classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, prompt, image, height, width, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): for name, module in self.image_unet.named_modules(): if isinstance(module, DualTransformer2DModel): module.mix_ratio = mix_ratio for i, type in enumerate(condition_types): if type == "text": module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings module.transformer_index_for_condition[i] = 1 # use the second (text) transformer else: module.condition_lengths[i] = 257 module.transformer_index_for_condition[i] = 0 # use the first (image) transformer @torch.no_grad() def __call__( self, prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], image: Union[str, List[str]], text_to_image_strength: float = 0.5, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionDualGuidedPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> text = "a red car in the sun" >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe.remove_unused_weights() >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> text_to_image_strength = 0.75 >>> image = pipe( ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator ... ).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.image_unet.config.sample_size * self.vae_scale_factor width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. 
Raise error if not correct self.check_inputs(prompt, image, height, width, callback_steps) # 2. Define call parameters prompt = [prompt] if not isinstance(prompt, list) else prompt image = [image] if not isinstance(image, list) else image batch_size = len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompts prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) prompt_types = ("text", "image") # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, dual_prompt_embeddings.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Combine the attention blocks of the image and text UNets self.set_transformer_params(text_to_image_strength, prompt_types) # 8. 
Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from typing import Callable, List, Optional, Union import numpy as np import PIL import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" image_feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers def __init__( self, image_feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( image_feature_extractor=image_feature_extractor, image_encoder=image_encoder, image_unet=image_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). """ def normalize_embeddings(encoder_output): embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) embeds = self.image_encoder.visual_projection(embeds) embeds_pooled = embeds[:, 0:1] embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: prompt = list(prompt) batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) image_embeddings = self.image_encoder(pixel_values) image_embeddings = normalize_embeddings(image_embeddings) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional 
embeddings for classifier free guidance if do_classifier_free_guidance: uncond_images: List[str] if negative_prompt is None: uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, PIL.Image.Image): uncond_images = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_images = negative_prompt uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) negative_prompt_embeds = self.image_encoder(pixel_values) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs def check_inputs(self, image, height, width, callback_steps): if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" f" {type(image)}" ) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Return the starting latents for the denoising loop, scaled by `self.scheduler.init_noise_sigma`.

    Args:
        batch_size: first dimension of the latent shape; also the required length of a generator list.
        num_channels_latents: second dimension of the latent shape.
        height, width: pixel sizes; each is floor-divided by `self.vae_scale_factor` to get the latent size.
        dtype: dtype passed to `randn_tensor` when new latents are sampled.
        device: device for sampling, or the device provided latents are moved to.
        generator: generator or list of generators passed to `randn_tensor`.
        latents: optional pre-made latents. They are only moved to `device`; their shape is not checked here.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // downscale, width // downscale)

    if isinstance(generator, list):
        num_generators = len(generator)
        if num_generators != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {num_generators}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

    if latents is not None:
        initial_latents = latents.to(device)
    else:
        initial_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial_latents * self.scheduler.init_noise_sigma
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionImageVariationPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> image = pipe(image, generator=generator).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.image_unet.config.sample_size * self.vae_scale_factor width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. 
Raise error if not correct self.check_inputs(image, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt image_embeddings = self._encode_prompt( image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, image_embeddings.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 4DoF/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from typing import Callable, List, Optional, Union import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
def _swap_unet_attention_blocks(self):
    """
    Exchange every `Transformer2DModel` block of the image UNet with the text UNet block at the same position.

    The position is taken from the module's dotted name in `self.image_unet`: everything before the last dot
    names the parent container and the last component is the integer index inside it. The same parent name is
    looked up in `self.text_unet`, so both UNets are expected to expose matching containers.
    """
    for qualified_name, candidate in self.image_unet.named_modules():
        if not isinstance(candidate, Transformer2DModel):
            continue

        container_name, position = qualified_name.rsplit(".", 1)
        position = int(position)

        image_container = self.image_unet.get_submodule(container_name)
        text_container = self.text_unet.get_submodule(container_name)
        # Tuple assignment reads both right-hand items before either slot is overwritten.
        image_container[position], text_container[position] = (
            text_container[position],
            image_container[position],
        )
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
""" def normalize_embeddings(encoder_output): embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) embeds_pooled = encoder_output.text_embeds embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) return embeds batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the call arguments; returns `None` when they are all acceptable.

    Raises:
        ValueError: in any of these cases, checked in this order:
            - `height` or `width` is not divisible by 8;
            - `callback_steps` is not a positive `int` (this includes `None`);
            - both `prompt` and `prompt_embeds` are given, or neither is;
            - `prompt` is given but is neither a `str` nor a `list`;
            - both `negative_prompt` and `negative_prompt_embeds` are given;
            - `prompt_embeds` and `negative_prompt_embeds` are both given with different `.shape`.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_are_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
            ignored if `guidance_scale` is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            VAE decoding step is skipped and the final latents are passed to the image processor instead.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an `ImagePipelineOutput` instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        **kwargs:
            Accepted for call compatibility; not read anywhere in this method.

    Examples:

    ```py
    >>> from diffusers import VersatileDiffusionTextToImagePipeline
    >>> import torch

    >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
    ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
    ... )
    >>> pipe.remove_unused_weights()
    >>> pipe = pipe.to("cuda")

    >>> generator = torch.Generator(device="cuda").manual_seed(0)
    >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
    >>> image.save("./astronaut.png")
    ```

    Returns:
        `ImagePipelineOutput` or `tuple`:
        An `ImagePipelineOutput` whose `images` field holds the post-processed result if `return_dict` is True,
        otherwise a one-element `tuple` containing that same post-processed result.
    """
    # 0. Default height and width to unet
    # (`or` means a falsy value such as None or 0 selects the default)
    height = height or self.image_unet.config.sample_size * self.vae_scale_factor
    width = width or self.image_unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    # NOTE: `negative_prompt` is not forwarded to `check_inputs` here.
    self.check_inputs(prompt, height, width, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    # (one latent per generated image; sampled latents use the dtype of the prompt embeddings)
    num_channels_latents = self.image_unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    for i, t in enumerate(self.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

        # perform guidance
        if do_classifier_free_guidance:
            # first half of the doubled batch is the unconditional prediction, second half the conditional one
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    # Decode the final latents with the VAE unless the caller asked for raw latents.
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    # NOTE(review): postprocess is also called for output_type == "latent"; its handling of that value is
    # defined in VaeImageProcessor, outside this block — confirm it passes latents through.
    image = self.image_processor.postprocess(image, output_type=output_type)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
    """
    Utility class for storing learned text embeddings for classifier free sampling

    Args:
        learnable (`bool`):
            When True, `embeddings` is a zero-initialised parameter of shape `(length, hidden_size)`. When
            False, `embeddings` is built from `None`.
        hidden_size (`int`, *optional*):
            Second dimension of the embeddings. Required when `learnable` is True.
        length (`int`, *optional*):
            First dimension of the embeddings. Required when `learnable` is True.

    Raises:
        ValueError: if `learnable` is True and `hidden_size` or `length` is `None`.
    """

    @register_to_config
    def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
        super().__init__()

        self.learnable = learnable

        if self.learnable:
            # Explicit raises rather than `assert`: assertions are removed under `python -O`, which would
            # let a missing size reach `torch.zeros` and fail there with a less helpful error.
            if hidden_size is None:
                raise ValueError("learnable=True requires `hidden_size` to be set")
            if length is None:
                raise ValueError("learnable=True requires `length` to be set")

            embeddings = torch.zeros(length, hidden_size)
        else:
            embeddings = None

        # NOTE(review): the non-learnable path relies on `torch.nn.Parameter` accepting `None` as its data
        # argument — confirm against the torch version in use.
        self.embeddings = torch.nn.Parameter(embeddings)
Args: vqvae ([`VQModel`]): Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. VQ Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). transformer ([`Transformer2DModel`]): Conditional transformer to denoise the encoded image latents. scheduler ([`VQDiffusionScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. """ vqvae: VQModel text_encoder: CLIPTextModel tokenizer: CLIPTokenizer transformer: Transformer2DModel learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings scheduler: VQDiffusionScheduler def __init__( self, vqvae: VQModel, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, transformer: Transformer2DModel, scheduler: VQDiffusionScheduler, learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, ): super().__init__() self.register_modules( vqvae=vqvae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, ) def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) text_input_ids = text_inputs.input_ids if text_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode(text_input_ids[:, 
self.tokenizer.model_max_length :]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] prompt_embeds = self.text_encoder(text_input_ids.to(self.device))[0] # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. # While CLIP does normalize the pooled output of the text transformer when combining # the image and text embeddings, CLIP does not directly normalize the last hidden state. # # CLIP normalizing the pooled output. # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True) # duplicate text embeddings for each generation per prompt prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: if self.learned_classifier_free_sampling_embeddings.learnable: negative_prompt_embeds = self.learned_classifier_free_sampling_embeddings.embeddings negative_prompt_embeds = negative_prompt_embeds.unsqueeze(0).repeat(batch_size, 1, 1) else: uncond_tokens = [""] * batch_size max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # See comment for normalizing text embeddings negative_prompt_embeds = negative_prompt_embeds / negative_prompt_embeds.norm(dim=-1, keepdim=True) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size 
* num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], num_inference_steps: int = 100, guidance_scale: float = 5.0, truncation_rate: float = 1.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero. 
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor` of shape (batch), *optional*): Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated of completely masked latent pixels. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # get the initial completely masked latents unless the user supplied it latents_shape = (batch_size, self.transformer.num_latent_pixels) if latents is None: mask_class = self.transformer.num_vector_embeds - 1 latents = torch.full(latents_shape, mask_class).to(self.device) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): raise ValueError( "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," f" {self.transformer.num_vector_embeds - 1} (inclusive)." 
) latents = latents.to(self.device) # set timesteps self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps_tensor = self.scheduler.timesteps.to(self.device) sample = latents for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the sample if we are doing classifier free guidance latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample # predict the un-noised image # model_output == `log_p_x_0` model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample if do_classifier_free_guidance: model_output_uncond, model_output_text = model_output.chunk(2) model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) model_output = self.truncate(model_output, truncation_rate) # remove `log(0)`'s (`-inf`s) model_output = model_output.clamp(-70) # compute the previous noisy sample x_t -> x_t-1 sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, sample) embedding_channels = self.vqvae.config.vq_embed_dim embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels) embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape) image = self.vqvae.decode(embeddings, force_not_quantize=True).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor: """ Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The lowest probabilities that would 
increase the cumulative probability above `truncation_rate` are set to zero. """ sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True) sorted_p_x_0 = torch.exp(sorted_log_p_x_0) keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate # Ensure that at least the largest probability is not zeroed out all_true = torch.full_like(keep_mask[:, 0:1, :], True) keep_mask = torch.cat((all_true, keep_mask), dim=1) keep_mask = keep_mask[:, :-1, :] keep_mask = keep_mask.gather(1, indices.argsort(1)) rv = log_p_x_0.clone() rv[~keep_mask] = -torch.inf # -inf = log(0) return rv ================================================ FILE: 4DoF/diffusers/schedulers/__init__.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from ..utils import ( OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available, is_torchsde_available, ) try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_objects import * # noqa F403 else: from 
.scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, ) try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .scheduling_lms_discrete import LMSDiscreteScheduler try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_consistency_models.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class CMStochasticIterativeSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
    """
    Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the
    paper [1].

    [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
    https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based
    Generative Models." https://arxiv.org/abs/2206.00364

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        sigma_min (`float`):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation.
        sigma_max (`float`):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation.
        sigma_data (`float`):
            The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the
            original implementation, which is also the original value suggested in the EDM paper.
        s_noise (`float`):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011]. This was set to 1.0 in the original implementation.
        rho (`float`):
            The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was
            set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper.
        clip_denoised (`bool`):
            Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`.

    An explicit timestep schedule is not a constructor argument; it can be supplied through the `timesteps` argument
    of `set_timesteps`, which requires strictly descending values.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 40,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        s_noise: float = 1.0,
        rho: float = 7.0,
        clip_denoised: bool = True,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        ramp = np.linspace(0, 1, num_train_timesteps)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # setable values
        self.num_inference_steps = None
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps)
        self.custom_timesteps = False
        self.is_scale_input_called = False

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        The lookup is an exact equality match; `.item()` requires that exactly one entry matches.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        return indices.item()

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        # Get sigma corresponding to timestep
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_idx = self.index_for_timestep(timestep)
        sigma = self.sigmas[step_idx]

        sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)

        # Recorded so that `step` can warn when scaling was skipped.
        self.is_scale_input_called = True
        return sample

    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
        """
        Gets scaled timesteps from the Karras sigmas, for input to the consistency model.

        Args:
            sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas
        Returns:
            `float` or `np.ndarray`: scaled input timestep or scaled input timestep array
        """
        if not isinstance(sigmas, np.ndarray):
            sigmas = np.array(sigmas, dtype=np.float64)

        # The small additive constant keeps the logarithm finite for a zero sigma.
        timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44)

        return timesteps

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, optional):
                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
                must be `None`. The values must be strictly descending and start below
                `self.config.num_train_timesteps`.

        Raises:
            `ValueError`: If neither or both of `num_inference_steps` and `timesteps` are given, if custom
            `timesteps` are not strictly descending or start too high, or if `num_inference_steps` exceeds
            `self.config.num_train_timesteps`.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

        # Follow DDPMScheduler custom timesteps logic
        if timesteps is not None:
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps

            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            self.custom_timesteps = False

        # Map timesteps to Karras sigmas directly for multistep sampling
        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
        num_train_timesteps = self.config.num_train_timesteps
        ramp = timesteps[::-1].copy()
        ramp = ramp / (num_train_timesteps - 1)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # Append sigma_min as the terminal sigma. It is read from the config like every other
        # hyper-parameter in this class; no `sigma_min` attribute is assigned in `__init__`.
        sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

    # Modified _convert_to_karras implementation that takes in ramp as argument
    def _convert_to_karras(self, ramp):
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = self.config.sigma_min
        sigma_max: float = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        # ramp == 0 yields sigma_max and ramp == 1 yields sigma_min.
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def get_scalings(self, sigma):
        """Return the `(c_skip, c_out)` scalings for `sigma`, without the boundary-condition shift."""
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    def get_scalings_for_boundary_condition(self, sigma):
        """
        Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper.
        This enforces the consistency model boundary condition.

        Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min.

        Args:
            sigma (`torch.FloatTensor`):
                The current sigma in the Karras sigma schedule.
        Returns:
            `tuple`:
                A two-element tuple where c_skip (which weights the current sample) is the first element and c_out
                (which weights the consistency model output) is the second element.
        """
        sigma_min = self.config.sigma_min
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
        c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, *optional*): Random number generator.
            return_dict (`bool`): option for returning tuple rather than CMStochasticIterativeSchedulerOutput class
        Returns:
            [`CMStochasticIterativeSchedulerOutput`] or `tuple`: [`CMStochasticIterativeSchedulerOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
            tensor.

        Raises:
            `ValueError`: If `timestep` is a Python `int`, a `torch.IntTensor` or a `torch.LongTensor`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                f" `{self.__class__}.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        sigma_min = self.config.sigma_min
        sigma_max = self.config.sigma_max

        step_index = self.index_for_timestep(timestep)

        # sigma_next corresponds to next_t in original implementation
        sigma = self.sigmas[step_index]
        if step_index + 1 < self.config.num_train_timesteps:
            sigma_next = self.sigmas[step_index + 1]
        else:
            # Set sigma_next to sigma_min
            sigma_next = self.sigmas[-1]

        # Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)

        # 1. Denoise model output using boundary conditions
        denoised = c_out * model_output + c_skip * sample
        if self.config.clip_denoised:
            denoised = denoised.clamp(-1, 1)

        # 2. Sample z ~ N(0, s_noise^2 * I)
        # Noise is not used for onestep sampling.
        if len(self.timesteps) > 1:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
        else:
            noise = torch.zeros_like(model_output)
        z = noise * self.config.s_noise

        sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)

        # 3. Return noisy sample
        # tau = sigma_hat, eps = sigma_min
        prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5

        if not return_dict:
            return (prev_sample,)

        return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append trailing singleton dims so sigma broadcasts against original_samples.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps


# ================================================
# FILE: 4DoF/diffusers/schedulers/scheduling_ddim.py
# ================================================
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) Args: betas (`torch.FloatTensor`): the betas that the scheduler is being initialized with. Returns: `torch.FloatTensor`: rescaled betas with zero terminal SNR """ # Convert betas to alphas_bar_sqrt alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_bar_sqrt = alphas_cumprod.sqrt() # Store old values. alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() # Shift so the last timestep is zero. alphas_bar_sqrt -= alphas_bar_sqrt_T # Scale so the first timestep is back to the old value. 
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the denoised sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, default `"leading"`):
            The way the timesteps should be scaled, one of `"linspace"`, `"leading"` or `"trailing"`. Refer to Table 2.
            of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for
            more information.
        rescale_betas_zero_snr (`bool`, default `False`):
            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
            medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        # Until `set_timesteps` is called, the schedule is every training timestep in descending order.
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler needs no scaling, so the sample is returned unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep (unused)

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """
        Compute the variance term of formula (16) of the DDIM paper for the move `timestep -> prev_timestep`:
        `(1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)`.

        A negative `prev_timestep` means there is no previous step; `final_alpha_cumprod` is used in that case.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487

        Args:
            sample (`torch.FloatTensor`):
                predicted `x_0` of shape `(batch_size, ...)`. The quantile is computed per batch element over all
                remaining dimensions, so inputs are not restricted to 4-D image tensors.

        Returns:
            `torch.FloatTensor`: thresholded sample with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        original_shape = sample.shape
        batch_size = original_shape[0]
        # Number of values per batch element; `np.prod(())` is 1, so a 1-D input is handled too.
        flat_dim = int(np.prod(original_shape[1:]))

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, flat_dim)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(original_shape)
        sample = sample.to(dtype)

        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device the resulting `timesteps` tensor is moved to.

        Raises:
            ValueError: if `num_inference_steps` exceeds `config.num_train_timesteps`, or if
                `config.timestep_spacing` is not one of `"linspace"`, `"leading"` or `"trailing"`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                f" with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by the (integer) ratio, in descending order
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # the ratio is fractional here, so the grid is rounded before casting to integer timesteps
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called, if `config.prediction_type` is unknown, or if both
                `generator` and `variance_noise` are given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        # 8. add the stochastic term sigma_t * noise; skipped entirely for deterministic DDIM (eta == 0)
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            noise_term = std_dev_t * variance_noise

            prev_sample = prev_sample + noise_term

        if not return_dict:
            return (prev_sample,)

        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Return `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise` for the given `timesteps`.

        The per-timestep coefficients are flattened and then given trailing singleton dimensions until they have as
        many dimensions as `original_samples`, so that they broadcast against it.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Return `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for the given `timesteps`.

        The coefficients are broadcast against `sample` in the same way as in `add_noise`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps
@flax.struct.dataclass
class DDIMSchedulerState:
    # Immutable state container for `FlaxDDIMScheduler`; updated copies are produced with `.replace(...)`.
    common: CommonSchedulerState
    final_alpha_cumprod: jnp.ndarray

    # setable values
    init_noise_sigma: jnp.ndarray
    timesteps: jnp.ndarray
    num_inference_steps: Optional[int] = None

    @classmethod
    def create(
        cls,
        common: CommonSchedulerState,
        final_alpha_cumprod: jnp.ndarray,
        init_noise_sigma: jnp.ndarray,
        timesteps: jnp.ndarray,
    ):
        """Build a fresh state; `num_inference_steps` keeps its default (`None`)."""
        field_values = {
            "common": common,
            "final_alpha_cumprod": final_alpha_cumprod,
            "init_noise_sigma": init_noise_sigma,
            "timesteps": timesteps,
        }
        return cls(**field_values)


@dataclass
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
    # Scheduler output that additionally carries the scheduler state.
    state: DDIMSchedulerState
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`jnp.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`):
            what the model predicts. One of `epsilon` (the noise), `sample` (the denoised sample) or `v_prediction`;
            any other value makes `step` raise a `ValueError`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
    """

    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

    dtype: jnp.dtype

    @property
    def has_state(self):
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[jnp.ndarray] = None,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        dtype: jnp.dtype = jnp.float32,
    ):
        # Only `dtype` is kept on the instance here; the schedule arrays live in the state built by `create_state`.
        self.dtype = dtype

    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
        """
        Build the initial scheduler state.

        Args:
            common (`CommonSchedulerState`, optional):
                shared schedule state; created from this scheduler when not given.

        Returns:
            `DDIMSchedulerState`: state whose `timesteps` cover every training timestep in descending order.
        """
        if common is None:
            common = CommonSchedulerState.create(self)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        final_alpha_cumprod = (
            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
        )

        # standard deviation of the initial noise distribution
        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

        return DDIMSchedulerState.create(
            common=common,
            final_alpha_cumprod=final_alpha_cumprod,
            init_noise_sigma=init_noise_sigma,
            timesteps=timesteps,
        )

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Return the sample unchanged; this scheduler does not scale the model input.

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance (unused).
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep (unused)

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDIMSchedulerState`):
                the `FlaxDDIMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`):
                not used by this scheduler.

        Returns:
            `DDIMSchedulerState`: a copy of `state` with `num_inference_steps` and `timesteps` replaced.

        Raises:
            ValueError: if `num_inference_steps` is not in `[1, config.num_train_timesteps]`.
        """
        # Out-of-range values would either divide by zero or collapse the step ratio to 0 (all timesteps equal).
        if num_inference_steps < 1 or num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} must be between 1 and"
                f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps}."
            )

        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # creates integer timesteps by multiplying by the (integer) ratio, in descending order
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset

        return state.replace(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
        )

    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
        """
        Compute `(1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)` (formula (16) of the
        DDIM paper, squared). `state.final_alpha_cumprod` stands in for `alpha_bar_prev` when `prev_timestep < 0`.
        """
        alpha_prod_t = state.common.alphas_cumprod[timestep]
        # `jnp.where` (instead of a Python `if`) selects between the two candidate values.
        alpha_prod_t_prev = jnp.where(
            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
        )
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        eta: float = 0.0,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            eta (`float`):
                scales `sigma_t` in the "direction pointing to x_t" term. Note that this method does not add a random
                noise term to the result.
            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

        Returns:
            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple` `(prev_sample, state)`.

        Raises:
            ValueError: if `set_timesteps` has not been called or `config.prediction_type` is unknown.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

        alphas_cumprod = state.common.alphas_cumprod
        final_alpha_cumprod = state.final_alpha_cumprod

        # 2. compute alphas, betas
        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(state, timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, state)

        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        state: DDIMSchedulerState,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `add_noise_common` with the shared schedule state."""
        return add_noise_common(state.common, original_samples, noise, timesteps)

    def get_velocity(
        self,
        state: DDIMSchedulerState,
        sample: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """Delegate to `get_velocity_common` with the shared schedule state."""
        return get_velocity_common(state.common, sample, noise, timesteps)

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput, deprecate @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs (empty when `num_diffusion_timesteps <= 0`).

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    # Evaluate alpha_bar once per grid point i / steps; neighbouring betas share a grid point, so this halves the
    # number of evaluations without changing any value.
    alpha_bars = [alpha_bar_fn(i / steps) for i in range(steps + 1)] if steps > 0 else []

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    betas = [min(1 - nxt / cur, max_beta) for cur, nxt in zip(alpha_bars, alpha_bars[1:])]
    return torch.tensor(betas, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    clip_sample: bool = True,
    set_alpha_to_zero: bool = True,
    steps_offset: int = 0,
    prediction_type: str = "epsilon",
    clip_sample_range: float = 1.0,
    **kwargs,
):
    """
    Build the beta/alpha schedule and the default (ascending) timestep grid.

    The accepted arguments are described in the class docstring. `**kwargs` is only inspected for the deprecated
    `set_alpha_to_one` key: when present and not `None`, a deprecation is reported and its value is used as
    `set_alpha_to_zero`.

    Raises:
        NotImplementedError: if `trained_betas` is `None` and `beta_schedule` is not one of `linear`,
            `scaled_linear` or `squaredcos_cap_v2`.
    """
    if kwargs.get("set_alpha_to_one", None) is not None:
        deprecation_message = (
            "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
        )
        deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
        set_alpha_to_zero = kwargs["set_alpha_to_one"]

    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # At every step in inverted ddim, we are looking into the next alphas_cumprod
    # For the final step, there is no next alphas_cumprod, and the index is out of bounds
    # `set_alpha_to_zero` decides whether we set this parameter simply to zero
    # in this case, self.step() just output the predicted noise
    # or whether we use the final alpha of the "non-previous" one.
    self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # setable values
    self.num_inference_steps = None
    # Inversion walks forward in time, so the default grid is ascending (0 .. num_train_timesteps - 1).
    self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in inverted ddim, we are looking into the next alphas_cumprod # For the final step, there is no next alphas_cumprod, and the index is out of bounds # `set_alpha_to_zero` decides whether we set this parameter simply to zero # in this case, self.step() just output the predicted noise # or whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.config.steps_offset def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: # 1. get previous step value (=t+1) prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas # change original implementation to exactly match noise levels for analogous forward process alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = ( self.alphas_cumprod[prev_timestep] if prev_timestep < self.config.num_train_timesteps else self.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if not return_dict: return (prev_sample, pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_ddim_parallel.py ================================================ # Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous alpha_bar(t) curve, t in [0, 1], into a beta schedule.

    alpha_bar(t) is the cumulative product of (1 - beta) up to time t. Each beta is derived from the ratio of
    alpha_bar at the two ends of its sub-interval and is capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the schedule.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(end of interval i) / alpha_bar(start of interval i), capped at max_beta
    schedule = [
        min(1 - alpha_bar_fn((idx + 1) / total) / alpha_bar_fn(idx / total), max_beta) for idx in range(total)
    ]
    return torch.tensor(schedule, dtype=torch.float32)
class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, default `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, default `False`):
            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
            medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1
    _is_ode_scheduler = True

    @register_to_config
    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep=None):
        """
        Compute the DDIM variance term (formula (16) of https://arxiv.org/pdf/2010.02502.pdf) for one timestep.

        When `prev_timestep` is omitted it is derived from the current inference stride. A negative
        `prev_timestep` falls back to `self.final_alpha_cumprod`.
        """
        if prev_timestep is None:
            prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def _batch_get_variance(self, t, prev_t):
        """
        Batched counterpart of `_get_variance`: `t` and `prev_t` are index tensors of matching shape.
        """
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
        # NOTE(review): entries with a negative previous timestep use a hard-coded 1.0 here, whereas
        # `_get_variance` uses `self.final_alpha_cumprod`; the two differ when `set_alpha_to_one=False`.
        # Kept as-is to preserve existing numerical behaviour — confirm whether this is intended.
        alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    # (generalized from 4-D inputs to inputs of any rank with a leading batch dimension)
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487

        Args:
            sample (`torch.FloatTensor`): tensor whose first dimension is the batch; all remaining dimensions are
                flattened together for the per-item quantile.

        Returns:
            `torch.FloatTensor`: thresholded tensor with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        original_shape = sample.shape
        batch_size = original_shape[0]

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, int(np.prod(original_shape[1:])))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(original_shape)
        sample = sample.to(dtype)

        return sample

    # Adapted from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                device the timestep tensor is moved to.

        Raises:
            ValueError: if `num_inference_steps` exceeds `config.num_train_timesteps`, or if
                `config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            # 'linspace' is handled above, so it is listed among the valid choices.
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have not effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called, if `config.prediction_type` is unsupported, or if
                both `generator` and `variance_noise` are given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def batch_step_no_noise(
        self,
        model_output: torch.FloatTensor,
        timesteps: List[int],
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
    ) -> torch.FloatTensor:
        """
        Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
        Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
        is pre-sampled by the pipeline.

        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timesteps (`List[int]`):
                current discrete timesteps in the diffusion chain. Despite the annotation, the value must be a
                tensor: integer arithmetic and `.view` are applied to it directly.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step. Must be `0.0`.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have not effect.

        Returns:
            `torch.FloatTensor`: sample tensor at previous timestep.

        Raises:
            ValueError: if `set_timesteps` has not been called, if `eta` is not `0.0`, or if
                `config.prediction_type` is unsupported.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if eta != 0.0:
            raise ValueError(f"`eta` must be 0.0 for `batch_step_no_noise`, but got {eta}.")

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        t = timesteps
        prev_t = t - self.config.num_train_timesteps // self.num_inference_steps

        # reshape to (N, 1, ..., 1) so per-item scalars broadcast against `model_output`
        t = t.view(-1, *([1] * (model_output.ndim - 1)))
        prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))

        # 2. compute alphas, betas
        # the lookup tables are moved (and re-bound on `self`) to the model output's device
        self.alphas_cumprod = self.alphas_cumprod.to(model_output.device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(model_output.device)
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
        # NOTE(review): `step` uses `self.final_alpha_cumprod` for a negative previous timestep, while this
        # batched path hard-codes 1.0; they differ when `set_alpha_to_one=False`. Kept as-is to preserve
        # existing numerical behaviour — confirm whether this is intended.
        alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._batch_get_variance(t, prev_t).to(model_output.device).view(*alpha_prod_t_prev.shape)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        return prev_sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Mix `original_samples` with `noise` at the given `timesteps`:
        `sqrt(alpha_cumprod) * original_samples + sqrt(1 - alpha_cumprod) * noise`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # append trailing singleton dims so the per-item factor broadcasts over the sample
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Compute the velocity target at the given `timesteps`:
        `sqrt(alpha_cumprod) * noise - sqrt(1 - alpha_cumprod) * sample`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # append trailing singleton dims so the per-item factor broadcasts over the sample
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve (the cumulative product of `1 - beta` over `t` in `[0, 1]`) into a
    per-step beta schedule.

    Step `i` gets `beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; keep it below 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the shape of `alpha_bar`.
            Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps`.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    def alpha_bar_fn(t):
        if alpha_transform_type == "cosine":
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((i + 1) / total) / alpha_bar_fn(i / total), max_beta)
        for i in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the denoised sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

    Exactly one of `num_inference_steps` and `timesteps` must be provided. The resulting schedule is stored in
    `self.timesteps` as an int64 tensor on `device`, in descending order.

    Args:
        num_inference_steps (`Optional[int]`):
            the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
            `timesteps` must be `None`. The steps are spread over the training range according to
            `self.config.timestep_spacing` (`linspace`, `leading` or `trailing`).
        device (`str` or `torch.device`, optional):
            the device to which the timesteps are moved to.
        timesteps (`List[int]`, optional):
            custom timesteps used to support arbitrary spacing between timesteps. They must be strictly
            descending and start below `self.config.num_train_timesteps`. If passed, `num_inference_steps` must
            be `None`.

    Raises:
        ValueError: if both or neither of `num_inference_steps` / `timesteps` are given, if the custom timesteps
            are empty, not strictly descending or out of range, if `num_inference_steps` is not positive or
            exceeds the number of training timesteps, or if the configured spacing is unknown.
    """
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

    # Without this guard the `None > int` comparison below fails with an opaque TypeError.
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

    if timesteps is not None:
        # An empty schedule would otherwise surface as an IndexError on `timesteps[0]`.
        if len(timesteps) == 0:
            raise ValueError("`custom_timesteps` must contain at least one timestep.")

        if any(later >= earlier for earlier, later in zip(timesteps, timesteps[1:])):
            raise ValueError("`custom_timesteps` must be in descending order.")

        if timesteps[0] >= self.config.num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps}."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.custom_timesteps = True
    else:
        # Zero steps would divide by zero when computing the step ratio below.
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps`: {num_inference_steps} must be a positive integer.")

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        self.custom_timesteps = False

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # integer multiples of the ratio, reversed into descending order
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # count down from the end of the training range, then shift to 0-based indices
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

    self.timesteps = torch.from_numpy(timesteps).to(device)
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    generator=None,
    return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
    """
    Run one reverse-diffusion update: turn the model output at `timestep` into the sample for the previous
    timestep.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model. When the scheduler's
            variance type is `learned` or `learned_range` and the output has twice the channels of `sample`, the
            second half along dim 1 is treated as the predicted variance.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`): current instance of sample being created by diffusion process.
        generator: random number generator forwarded to the noise sampler.
        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

    Returns:
        [`DDPMSchedulerOutput`] if `return_dict` is True, otherwise a one-element `tuple` holding the previous
        sample.

    Raises:
        ValueError: if `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    t = timestep
    prev_t = self.previous_timestep(t)

    # Peel off a learned-variance head when the model emits one.
    predicted_variance = None
    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)

    # 1. cumulative and per-step alphas / betas
    alpha_bar_t = self.alphas_cumprod[t]
    alpha_bar_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
    beta_bar_t = 1 - alpha_bar_t
    beta_bar_prev = 1 - alpha_bar_prev
    alpha_t = alpha_bar_t / alpha_bar_prev
    beta_t = 1 - alpha_t

    # 2. "predicted x_0", formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        x0_estimate = (sample - beta_bar_t ** (0.5) * model_output) / alpha_bar_t ** (0.5)
    elif self.config.prediction_type == "sample":
        x0_estimate = model_output
    elif self.config.prediction_type == "v_prediction":
        x0_estimate = (alpha_bar_t**0.5) * sample - (beta_bar_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. keep the x_0 estimate in range: dynamic thresholding takes precedence over plain clipping
    if self.config.thresholding:
        x0_estimate = self._threshold_sample(x0_estimate)
    elif self.config.clip_sample:
        clip_range = self.config.clip_sample_range
        x0_estimate = x0_estimate.clamp(-clip_range, clip_range)

    # 4./5. posterior mean µ_t as a blend of x_0 and x_t, formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    x0_weight = (alpha_bar_prev ** (0.5) * beta_t) / beta_bar_t
    xt_weight = alpha_t ** (0.5) * beta_bar_prev / beta_bar_t
    prev_sample = x0_weight * x0_estimate + xt_weight * sample

    # 6. stochastic term; nothing is added on the final step (t == 0)
    noise_term = 0
    if t > 0:
        gaussian = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )
        spread = self._get_variance(t, predicted_variance=predicted_variance)
        if self.variance_type == "fixed_small_log":
            # already a standard deviation
            noise_term = spread * gaussian
        elif self.variance_type == "learned_range":
            # a log-variance
            noise_term = torch.exp(0.5 * spread) * gaussian
        else:
            noise_term = (spread**0.5) * gaussian

    prev_sample = prev_sample + noise_term

    if return_dict:
        return DDPMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=x0_estimate)
    return (prev_sample,)
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order (i.e. the next smaller one).

    With a custom schedule the answer is the entry after `timestep` in `self.timesteps`, or `-1` (as a tensor)
    when `timestep` is the last entry. Otherwise it is `timestep` minus the integer stride
    `num_train_timesteps // num_inference_steps`, where the training step count stands in for
    `num_inference_steps` when the latter is unset (or zero).

    Args:
        timestep: the current timestep; for a custom schedule it must occur in `self.timesteps`.

    Returns:
        The previous timestep; it can be negative once the chain is exhausted.
    """
    if not self.custom_timesteps:
        step_count = self.num_inference_steps or self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // step_count

    # first position at which the schedule equals the current timestep
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    final_position = self.timesteps.shape[0] - 1
    return torch.tensor(-1) if position == final_position else self.timesteps[position + 1]
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, get_velocity_common, ) @flax.struct.dataclass class DDPMSchedulerState: common: CommonSchedulerState # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) @dataclass class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput): state: DDPMSchedulerState class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        prediction_type (`str`, default `epsilon`):
            indicates what the model predicts: the noise (`epsilon`), the denoised sample (`sample`) or the
            velocity (`v_prediction`). One of `epsilon`, `sample` or `v_prediction`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            the `dtype` used for params and computation.
""" _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: if common is None: common = CommonSchedulerState.create(self) # standard deviation of the initial noise distribution init_noise_sigma = jnp.array(1.0, dtype=self.dtype) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] return DDPMSchedulerState.create( common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) def scale_model_input( self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. sample (`jnp.ndarray`): input sample timestep (`int`, optional): current timestep Returns: `jnp.ndarray`: scaled input sample """ return sample def set_timesteps( self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> DDPMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: state (`DDIMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] return state.replace( num_inference_steps=num_inference_steps, timesteps=timesteps, ) def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, variance_type=None): alpha_prod_t = state.common.alphas_cumprod[t] alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t] if variance_type is None: variance_type = self.config.variance_type # hacks - were probably added for training stability if variance_type == "fixed_small": variance = jnp.clip(variance, a_min=1e-20) # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = jnp.log(jnp.clip(variance, a_min=1e-20)) elif variance_type == "fixed_large": variance = state.common.betas[t] elif variance_type == "fixed_large_log": # Glide max_log variance = jnp.log(state.common.betas[t]) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": min_log = variance max_log = state.common.betas[t] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log return variance def step( self, state: DDPMSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, key: Optional[jax.random.KeyArray] = None, return_dict: bool = True, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. 
Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. key (`jax.random.KeyArray`): a PRNG key. return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class Returns: [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ t = timestep if key is None: key = jax.random.PRNGKey(0) if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) else: predicted_variance = None # 1. compute alphas, betas alpha_prod_t = state.common.alphas_cumprod[t] alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype)) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " " for the FlaxDDPMScheduler." ) # 3. 
Clip "predicted x_0" if self.config.clip_sample: pred_original_sample = jnp.clip(pred_original_sample, -1, 1) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample # 6. Add noise def random_variance(): split_key = jax.random.split(key, num=1) noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype)) pred_prev_sample = pred_prev_sample + variance if not return_dict: return (pred_prev_sample, state) return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state) def add_noise( self, state: DDPMSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: return add_noise_common(state.common, original_samples, noise, timesteps) def get_velocity( self, state: DDPMSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: return get_velocity_common(state.common, sample, noise, timesteps) def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_ddpm_parallel.py ================================================ # Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput class DDPMParallelSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
""" prev_sample: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample for numerical stability. clip_sample_range (`float`, default `1.0`): the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. 
@register_to_config
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.__init__
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    variance_type: str = "fixed_small",
    clip_sample: bool = True,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    clip_sample_range: float = 1.0,
    sample_max_value: float = 1.0,
    timestep_spacing: str = "leading",
    steps_offset: int = 0,
):
    """
    Build the beta schedule and the derived alpha tables, then reset the mutable sampling state.

    The betas are taken verbatim from `trained_betas` when it is given; otherwise they are generated
    from `beta_schedule` (`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`) using
    `beta_start`, `beta_end` and `num_train_timesteps`. Arguments that are not used in this body
    (`clip_sample`, `prediction_type`, `thresholding`, ...) are read by other methods through
    `self.config`.

    Raises:
        NotImplementedError: if `trained_betas` is `None` and `beta_schedule` is not a supported name.
    """
    if trained_betas is not None:
        schedule = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        schedule = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model:
        # interpolate linearly in sqrt(beta) space, then square.
        sqrt_schedule = torch.linspace(
            beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32
        )
        schedule = sqrt_schedule**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        schedule = betas_for_alpha_bar(num_train_timesteps)
    elif beta_schedule == "sigmoid":
        # GeoDiff sigmoid schedule
        grid = torch.linspace(-6, 6, num_train_timesteps)
        schedule = torch.sigmoid(grid) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

    self.betas = schedule
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    # stands in for alpha_cumprod at "timestep -1"
    self.one = torch.tensor(1.0)

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # setable values
    self.custom_timesteps = False
    self.num_inference_steps = None
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.timesteps = torch.from_numpy(descending_steps)

    self.variance_type = variance_type
custom_timesteps (`List[int]`, optional): custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` must be `None`. """ if num_inference_steps is not None and timesteps is not None: raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") if timesteps is not None: for i in range(1, len(timesteps)): if timesteps[i] >= timesteps[i - 1]: raise ValueError("`custom_timesteps` must be in descending order.") if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( f"`timesteps` must start before `self.config.train_timesteps`:" f" {self.config.num_train_timesteps}." ) timesteps = np.array(timesteps, dtype=np.int64) self.custom_timesteps = True else: if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps self.custom_timesteps = False # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) .round()[::-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """
    Apply dynamic thresholding to a predicted x_0.

    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.FloatTensor`):
            predicted original sample of shape `(batch_size, channels, *dims)`. Any number of trailing
            dimensions is accepted (e.g. 1D signals, images, video), not only `(height, width)`.

    Returns:
        `torch.FloatTensor`: thresholded sample with the same shape and dtype as the input.
    """
    dtype = sample.dtype
    # Only the first two axes are named; everything after them is flattened per batch element.
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten sample for doing quantile calculation along each batch element
    sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class Returns: [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ t = timestep prev_t = self.previous_timestep(t) if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev current_alpha_t = alpha_prod_t / alpha_prod_t_prev current_beta_t = 1 - current_alpha_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for the DDPMScheduler." ) # 3. Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 4. 
def batch_step_no_noise(
    self,
    model_output: torch.FloatTensor,
    timesteps: Union[List[int], torch.Tensor],
    sample: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
    Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
    is pre-sampled by the pipeline.

    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Note: the previous timestep is always derived from a uniform stride of
    `num_train_timesteps // num_inference_steps`; a custom timestep list is not consulted here.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timesteps (`List[int]` or `torch.Tensor`):
            current discrete timesteps in the diffusion chain, one per batch element. A plain int, a list of
            ints or an integer tensor are all accepted; the values are moved to `model_output.device`.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: sample tensor at previous timestep.

    Raises:
        ValueError: if `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    device = model_output.device
    # The signature advertises a list of ints, so normalize to an index tensor on the working device
    # before doing tensor arithmetic / reshaping on it.
    t = torch.as_tensor(timesteps, dtype=torch.long, device=device)

    num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
    prev_t = t - self.config.num_train_timesteps // num_inference_steps

    # Shape (batch, 1, 1, ...) so the per-timestep scalars broadcast against the samples.
    per_sample_shape = (-1,) + (1,) * (model_output.ndim - 1)
    t = t.reshape(per_sample_shape)
    prev_t = prev_t.reshape(per_sample_shape)

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        # The second half of the channels is the predicted variance; it is unused because no noise is added.
        model_output, _ = torch.split(model_output, sample.shape[1], dim=1)

    # 1. compute alphas, betas
    self.alphas_cumprod = self.alphas_cumprod.to(device)
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
    # A negative previous timestep means "before the first step", where the cumulative alpha is 1.
    alpha_prod_t_prev = torch.where(prev_t < 0, torch.ones_like(alpha_prod_t_prev), alpha_prod_t_prev)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction`  for the DDPMParallelScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    return pred_prev_sample
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising order.

    With the default (uniform) schedule this is `timestep` minus the stride
    `num_train_timesteps // num_inference_steps`, where `num_train_timesteps` is used as the step
    count when `num_inference_steps` has not been set. With custom timesteps the next entry of
    `self.timesteps` is returned, or a tensor holding `-1` when `timestep` is the last entry.
    """
    if not self.custom_timesteps:
        step_count = self.num_inference_steps or self.config.num_train_timesteps
        stride = self.config.num_train_timesteps // step_count
        return timestep - stride

    # Position of the first occurrence of `timestep` in the custom schedule.
    matches = (self.timesteps == timestep).nonzero(as_tuple=True)[0]
    position = matches[0]
    last_position = self.timesteps.shape[0] - 1
    if position == last_position:
        return torch.tensor(-1)
    return self.timesteps[position + 1]
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous alpha_bar(t) curve, defined on t in [0, 1], into a beta schedule.

    alpha_bar(t) is the cumulative product of (1 - beta) up to time t. Each beta is computed as
    `1 - alpha_bar(t2) / alpha_bar(t1)` for two consecutive grid points and capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by
        the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DEIS; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v-prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` algorithm_type (`str`, default `deis`): the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in the future lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. 
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[np.ndarray] = None,
    solver_order: int = 2,
    prediction_type: str = "epsilon",
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "deis",
    solver_type: str = "logrho",
    lower_order_final: bool = True,
    use_karras_sigmas: Optional[bool] = False,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta schedule, the derived alpha/sigma/lambda tables and the multistep solver state.

    The betas come from `trained_betas` when given, otherwise from `beta_schedule` (`linear`,
    `scaled_linear` or `squaredcos_cap_v2`). `algorithm_type` values `dpmsolver` / `dpmsolver++` and
    `solver_type` values `midpoint` / `heun` / `bh1` / `bh2` are accepted but re-registered in the
    config as `deis` and `logrho` respectively.

    Raises:
        NotImplementedError: for an unsupported `beta_schedule`, `algorithm_type` or `solver_type`.
    """
    if trained_betas is not None:
        schedule = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        schedule = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model:
        # interpolate linearly in sqrt(beta) space, then square.
        sqrt_schedule = torch.linspace(
            beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32
        )
        schedule = sqrt_schedule**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        schedule = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

    self.betas = schedule
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Currently we only support VP-type noise schedule
    self.alpha_t = torch.sqrt(self.alphas_cumprod)
    self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
    self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # settings for DEIS: names borrowed from other solvers are mapped onto the only supported variant
    if algorithm_type != "deis":
        if algorithm_type not in ["dpmsolver", "dpmsolver++"]:
            raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
        self.register_to_config(algorithm_type="deis")

    if solver_type != "logrho":
        if solver_type not in ["midpoint", "heun", "bh1", "bh2"]:
            raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
        self.register_to_config(solver_type="logrho")

    # setable values
    self.num_inference_steps = None
    ascending_steps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)
    self.timesteps = torch.from_numpy(ascending_steps[::-1].copy())
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
""" # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) self.sigmas = torch.from_numpy(sigmas) # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """
    Apply dynamic thresholding to a predicted clean sample.

    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    The quantile is taken per batch element over all remaining dimensions, so the sample may have any number of
    trailing dimensions after `(batch_size, channels)` and is not restricted to 4D image tensors.

    Args:
        sample (`torch.FloatTensor`):
            predicted `x_0` with at least two dimensions, laid out as `(batch_size, channels, *dims)`.

    Returns:
        `torch.FloatTensor`: the thresholded sample, with the same shape and dtype as the input.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten everything but the batch axis so the quantile is computed once per batch element.
    # `np.prod` of an empty list is 1, which keeps 2D `(batch_size, channels)` inputs working.
    sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    # Restore the caller's layout and dtype.
    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
def deis_first_order_update(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    prev_timestep: int,
    sample: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Advance the sample by one first-order DEIS step (equivalent to DDIM).

    Args:
        model_output (`torch.FloatTensor`): converted model output for the current timestep.
        timestep (`int`): current discrete timestep in the diffusion chain.
        prev_timestep (`int`): discrete timestep the sample is advanced to.
        sample (`torch.FloatTensor`): current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the sample tensor at the previous timestep.

    Raises:
        NotImplementedError: if `self.config.algorithm_type` is not `"deis"`.
    """
    # Schedule values at the step we are moving to ...
    target_lambda = self.lambda_t[prev_timestep]
    target_alpha = self.alpha_t[prev_timestep]
    target_sigma = self.sigma_t[prev_timestep]
    # ... and at the step we are currently on.
    current_lambda = self.lambda_t[timestep]
    current_alpha = self.alpha_t[timestep]

    if self.config.algorithm_type != "deis":
        raise NotImplementedError("only support log-rho multistep deis now")

    # Difference of lambda_t between the target and current steps.
    step_size = target_lambda - current_lambda
    sample_scale = target_alpha / current_alpha
    output_scale = target_sigma * (torch.exp(step_size) - 1.0)
    return sample_scale * sample - output_scale * model_output
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 if self.config.algorithm_type == "deis": def ind_fn(t, b, c): # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) return x_t else: raise NotImplementedError("only support log-rho multistep deis now") def multistep_deis_third_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the third-order multistep DEIS. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] rho_t, rho_s0, rho_s1, rho_s2 = ( sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1, simga_s2 / alpha_s2, ) if self.config.algorithm_type == "deis": def ind_fn(t, b, c, d): # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] numerator = t * ( np.log(c) * (np.log(d) - np.log(t) + 1) - np.log(d) * np.log(t) + np.log(d) + np.log(t) ** 2 - 2 * np.log(t) + 2 ) denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) return numerator / denominator coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) return x_t else: raise NotImplementedError("only support log-rho multistep deis now") def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the multistep DEIS. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_deis_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_deis_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 if not return_dict: return 
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix clean samples with noise at the given timesteps.

    Computes `sqrt(alphas_cumprod[t]) * original_samples + sqrt(1 - alphas_cumprod[t]) * noise`, with the
    per-timestep factors broadcast over the non-batch dimensions of `original_samples`.

    Args:
        original_samples (`torch.FloatTensor`): the samples to be noised.
        noise (`torch.FloatTensor`): the noise to mix in.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

    Returns:
        `torch.FloatTensor`: the noisy samples.
    """
    # Bring the schedule and the indices onto the samples' device (and the schedule onto their dtype).
    cumulative_alphas = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    trailing_axes = max(len(original_samples.shape) - 1, 0)

    def as_broadcastable(factors):
        # One factor per batch entry, padded with singleton axes up to the rank of the samples.
        flat = factors.flatten()
        return flat.reshape(flat.shape[0], *([1] * trailing_axes))

    signal_scale = as_broadcastable(cumulative_alphas[timesteps] ** 0.5)
    noise_scale = as_broadcastable((1 - cumulative_alphas[timesteps]) ** 0.5)

    return signal_scale * original_samples + noise_scale * noise
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta)
    up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor of length `num_diffusion_timesteps` holding the betas used
        by the scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i) on a uniform grid over [0, 1], capped at `max_beta`.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). 
sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. 
For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion.
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps`, this resets the multistep state: `self.sigmas`, `self.num_inference_steps`,
    `self.model_outputs` and `self.lower_order_nums` are all overwritten.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model. Required; the
            `None` default only exists for signature compatibility.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        TypeError: if `num_inference_steps` is `None`.
        ValueError: if `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    if num_inference_steps is None:
        # Previously this surfaced as an opaque "unsupported operand type(s)" TypeError further down.
        raise TypeError("`num_inference_steps` must be an integer, but `None` was passed to `set_timesteps`.")

    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
    last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = (
            np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
        )
    elif self.config.timestep_spacing == "leading":
        step_ratio = last_timestep // (num_inference_steps + 1)
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        # NOTE(review): the ratio uses `num_train_timesteps` while the range starts at `last_timestep`, so a
        # clipped schedule yields fewer than `num_inference_steps` steps here - confirm this is intended.
        step_ratio = self.config.num_train_timesteps / num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if self.config.use_karras_sigmas:
        # Replace the spacing chosen above: build Karras sigmas, then map each back to a (rounded) timestep.
        log_sigmas = np.log(sigmas)
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = np.flip(timesteps).copy().astype(np.int64)

    self.sigmas = torch.from_numpy(sigmas)

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps.
    _, unique_indices = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(unique_indices)]

    self.timesteps = torch.from_numpy(timesteps).to(device)

    # After de-duplication the effective step count may be smaller than requested.
    self.num_inference_steps = len(timesteps)

    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
    """Constructs the noise schedule of Karras et al. (2022).

    The schedule interpolates linearly in `sigma ** (1 / rho)` between the first and last entries of
    `in_sigmas` and raises the result back to the power `rho`. The result starts at `in_sigmas[0]` and ends at
    `in_sigmas[-1]`; it is built with NumPy, so a NumPy array is returned.
    """
    # 7.0 is the value used in the paper
    rho = 7.0

    # The last entry is taken as the "min" endpoint and the first as the "max" endpoint, whatever their order.
    sigma_min: float = in_sigmas[-1].item()
    sigma_max: float = in_sigmas[0].item()

    inv_rho_start = sigma_max ** (1 / rho)
    inv_rho_end = sigma_min ** (1 / rho)

    fractions = np.linspace(0, 1, num_inference_steps)
    return (inv_rho_start + fractions * (inv_rho_end - inv_rho_start)) ** rho
timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the converted model output. """ # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverMultistepScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. 
def dpm_solver_first_order_update(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    prev_timestep: int,
    sample: torch.FloatTensor,
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    One step for the first-order DPM-Solver (equivalent to DDIM).

    See https://arxiv.org/abs/2206.00927 for the detailed derivation.

    Args:
        model_output (`torch.FloatTensor`): converted model output for the current timestep.
        timestep (`int`): current discrete timestep in the diffusion chain.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        noise (`torch.FloatTensor`, *optional*):
            noise injected by the SDE variants; must be given when `algorithm_type` is `sde-dpmsolver` or
            `sde-dpmsolver++` and is ignored otherwise.

    Returns:
        `torch.FloatTensor`: the sample tensor at the previous timestep.

    Raises:
        NotImplementedError: if `self.config.algorithm_type` is not one of the four supported solver types.
    """
    lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
    alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
    sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
    h = lambda_t - lambda_s

    algorithm_type = self.config.algorithm_type
    if algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
    elif algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    elif algorithm_type == "sde-dpmsolver":
        assert noise is not None
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    else:
        # Without this branch an unknown type fell through to `return x_t` and raised UnboundLocalError.
        raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
    return x_t
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": assert noise is not None 
if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * (torch.exp(h) - 1.0)) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) return x_t def multistep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the third-order multistep DPM-Solver. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2], ) alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the multistep DPM-Solver. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) else: noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.dpm_solver_first_order_update( model_output, timestep, prev_timestep, sample, noise=noise ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 if 
not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py ================================================ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver from dataclasses import dataclass from typing import List, Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, ) @flax.struct.dataclass class DPMSolverMultistepSchedulerState: common: CommonSchedulerState alpha_t: jnp.ndarray sigma_t: jnp.ndarray lambda_t: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None # running values model_outputs: Optional[jnp.ndarray] = None lower_order_nums: Optional[jnp.int32] = None prev_timestep: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, alpha_t: jnp.ndarray, sigma_t: jnp.ndarray, lambda_t: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, alpha_t=alpha_t, sigma_t=sigma_t, lambda_t=lambda_t, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): state: DPMSolverMultistepSchedulerState class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. 
Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality samples, and it can generate quite good samples even in only 10 steps. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. 
We recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v_prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend using `dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend using the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. 
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. """ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: if common is None: common = CommonSchedulerState.create(self) # Currently we only support VP-type noise schedule alpha_t = jnp.sqrt(common.alphas_cumprod) sigma_t = jnp.sqrt(1 - common.alphas_cumprod) lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t) # settings for DPM-Solver if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]: raise NotImplementedError(f"{self.config.algorithm_type} does is not implemented for {self.__class__}") if self.config.solver_type not in ["midpoint", "heun"]: raise NotImplementedError(f"{self.config.solver_type} does is not implemented for {self.__class__}") # standard deviation of the initial noise distribution init_noise_sigma = jnp.array(1.0, dtype=self.dtype) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] return DPMSolverMultistepSchedulerState.create( common=common, alpha_t=alpha_t, sigma_t=sigma_t, lambda_t=lambda_t, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) def set_timesteps( self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple ) -> DPMSolverMultistepSchedulerState: """ Sets the discrete timesteps used for the 
diffusion chain. Supporting function to be run before inference. Args: state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. shape (`Tuple`): the shape of the samples to be generated. """ timesteps = ( jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .astype(jnp.int32) ) # initial running values model_outputs = jnp.zeros((self.config.solver_order,) + shape, dtype=self.dtype) lower_order_nums = jnp.int32(0) prev_timestep = jnp.int32(-1) cur_sample = jnp.zeros(shape, dtype=self.dtype) return state.replace( num_inference_steps=num_inference_steps, timesteps=timesteps, model_outputs=model_outputs, lower_order_nums=lower_order_nums, prev_timestep=prev_timestep, cur_sample=cur_sample, ) def convert_model_output( self, state: DPMSolverMultistepSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, ) -> jnp.ndarray: """ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an integral of the data prediction model. So we need to first convert the model output to the corresponding type to match the algorithm. Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or DPM-Solver++ for both noise prediction model and data prediction model. Args: model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. Returns: `jnp.ndarray`: the converted model output. """ # DPM-Solver++ needs to solve an integral of the data prediction model. 
if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." ) if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 dynamic_max_val = jnp.percentile( jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) ) dynamic_max_val = jnp.maximum( dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) ) x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": if self.config.prediction_type == "epsilon": return model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." 
) def dpm_solver_first_order_update( self, state: DPMSolverMultistepSchedulerState, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray, ) -> jnp.ndarray: """ One step for the first-order DPM-Solver (equivalent to DDIM). See https://arxiv.org/abs/2206.00927 for the detailed derivation. Args: model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. Returns: `jnp.ndarray`: the sample tensor at the previous timestep. """ t, s0 = prev_timestep, timestep m0 = model_output lambda_t, lambda_s = state.lambda_t[t], state.lambda_t[s0] alpha_t, alpha_s = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s = state.sigma_t[t], state.sigma_t[s0] h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * m0 elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * m0 return x_t def multistep_dpm_solver_second_order_update( self, state: DPMSolverMultistepSchedulerState, model_output_list: jnp.ndarray, timestep_list: List[int], prev_timestep: int, sample: jnp.ndarray, ) -> jnp.ndarray: """ One step for the second-order multistep DPM-Solver. Args: model_output_list (`List[jnp.ndarray]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. Returns: `jnp.ndarray`: the sample tensor at the previous timestep. 
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 ) return x_t def multistep_dpm_solver_third_order_update( self, state: DPMSolverMultistepSchedulerState, model_output_list: jnp.ndarray, timestep_list: List[int], prev_timestep: int, sample: jnp.ndarray, ) -> jnp.ndarray: """ One step for the third-order multistep DPM-Solver. Args: model_output_list (`List[jnp.ndarray]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. 
Returns: `jnp.ndarray`: the sample tensor at the previous timestep. """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1], state.lambda_t[s2], ) alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m0 D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((jnp.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t def step( self, state: DPMSolverMultistepSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, return_dict: bool = True, ) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. 
sample (`jnp.ndarray`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class Returns: [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) (step_index,) = jnp.where(state.timesteps == timestep, size=1) step_index = step_index[0] prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1]) model_output = self.convert_model_output(state, model_output, timestep, sample) model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0) model_outputs_new = model_outputs_new.at[-1].set(model_output) state = state.replace( model_outputs=model_outputs_new, prev_timestep=prev_timestep, cur_sample=sample, ) def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: return self.dpm_solver_first_order_update( state, state.model_outputs[-1], state.timesteps[step_index], state.prev_timestep, state.cur_sample, ) def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]]) return self.multistep_dpm_solver_second_order_update( state, state.model_outputs, timestep_list, state.prev_timestep, state.cur_sample, ) def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: timestep_list = jnp.array( [ state.timesteps[step_index - 2], state.timesteps[step_index - 1], state.timesteps[step_index], ] ) return self.multistep_dpm_solver_third_order_update( state, state.model_outputs, timestep_list, state.prev_timestep, state.cur_sample, ) 
step_2_output = step_2(state) step_3_output = step_3(state) if self.config.solver_order == 2: return step_2_output elif self.config.lower_order_final and len(state.timesteps) < 15: return jax.lax.select( state.lower_order_nums < 2, step_2_output, jax.lax.select( step_index == len(state.timesteps) - 2, step_2_output, step_3_output, ), ) else: return jax.lax.select( state.lower_order_nums < 2, step_2_output, step_3_output, ) step_1_output = step_1(state) step_23_output = step_23(state) if self.config.solver_order == 1: prev_sample = step_1_output elif self.config.lower_order_final and len(state.timesteps) < 15: prev_sample = jax.lax.select( state.lower_order_nums < 1, step_1_output, jax.lax.select( step_index == len(state.timesteps) - 1, step_1_output, step_23_output, ), ) else: prev_sample = jax.lax.select( state.lower_order_nums < 1, step_1_output, step_23_output, ) state = state.replace( lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order), ) if not return_dict: return (prev_sample, state) return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) def scale_model_input( self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: state (`DPMSolverMultistepSchedulerState`): the `FlaxDPMSolverMultistepScheduler` state data class instance. 
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`, where `N` is
    `num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # The message names the actual keyword argument so the caller can find it.
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Cap each beta so that 1 - beta never reaches zero.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. 
The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. whether the model's output contains the predicted Gaussian variance. 
For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps`, this also sets `self.noisiest_timestep`, updates `self.num_inference_steps` to the
    number of timesteps actually kept, and resets the multistep history (`self.model_outputs`,
    `self.lower_order_nums`).

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model. Must be given
            even though the signature defaults it to `None`.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    """
    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    # The threshold is a registered config value; `__init__` never stores it as an instance attribute, so it
    # is read from `self.config`. `.item()` unwraps the 0-dim search result so that `noisiest_timestep` is a
    # plain Python int rather than a tensor.
    clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
    self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx

    # Evenly spaced, increasing timesteps starting at 0; the end point (the noisiest timestep) is dropped.
    timesteps = (
        np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
    )

    if self.use_karras_sigmas:
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = timesteps.copy().astype(np.int64)

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps.
    _, unique_indices = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(unique_indices)]

    self.timesteps = torch.from_numpy(timesteps).to(device)

    self.num_inference_steps = len(timesteps)

    # Start a fresh multistep history for the new schedule.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Convert a noise level `sigma` into a (possibly fractional) index into `log_sigmas`.

    The index is found by locating the pair of neighbouring entries of `log_sigmas` around `log(sigma)` and
    interpolating linearly between their positions in log-sigma space. The result has the same shape as `sigma`.
    """
    target = np.log(sigma)

    # One row per tabulated log-sigma: non-negative where the target is at or above that entry.
    gaps = target - log_sigmas[:, np.newaxis]

    # The running count of non-negative gaps first reaches its maximum at the last such entry; the clip
    # keeps room for an upper neighbour.
    table_size = log_sigmas.shape[0]
    lower_pos = np.cumsum((gaps >= 0), axis=0).argmax(axis=0).clip(max=table_size - 2)
    upper_pos = lower_pos + 1

    lower_val = log_sigmas[lower_pos]
    upper_val = log_sigmas[upper_pos]

    # Interpolation weight of the upper neighbour, limited to [0, 1].
    weight = np.clip((lower_val - target) / (lower_val - upper_val), 0, 1)

    # Blend the two positions and restore the caller's shape.
    position = (1 - weight) * lower_pos + weight * upper_pos
    return position.reshape(sigma.shape)
def convert_model_output(
    self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
    """
    Translate the raw network output into the quantity the configured solver integrates.

    The `dpmsolver++` family works with a prediction of the clean sample, the `dpmsolver` family with a
    prediction of the noise. The network's own output convention (`prediction_type`) is independent of that
    choice, so every combination is converted here.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the converted model output.

    Raises:
        ValueError: if `prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    algorithm = self.config.algorithm_type

    # DPM-Solver++: the solver needs the data (x0) prediction.
    if algorithm in ["dpmsolver++", "sde-dpmsolver++"]:
        kind = self.config.prediction_type
        if kind == "sample":
            denoised = model_output
        elif kind == "epsilon":
            # Only the "mean" part is used when the network also predicts a variance.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            alpha = self.alpha_t[timestep]
            sigma = self.sigma_t[timestep]
            denoised = (sample - sigma * model_output) / alpha
        elif kind == "v_prediction":
            alpha = self.alpha_t[timestep]
            sigma = self.sigma_t[timestep]
            denoised = alpha * sample - sigma * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            denoised = self._threshold_sample(denoised)
        return denoised

    # DPM-Solver: the solver needs the noise (epsilon) prediction.
    if algorithm in ["dpmsolver", "sde-dpmsolver"]:
        kind = self.config.prediction_type
        if kind == "epsilon":
            # Only the "mean" part is used when the network also predicts a variance.
            if self.config.variance_type in ["learned", "learned_range"]:
                noise_pred = model_output[:, :3]
            else:
                noise_pred = model_output
        elif kind == "sample":
            alpha = self.alpha_t[timestep]
            sigma = self.sigma_t[timestep]
            noise_pred = (sample - alpha * model_output) / sigma
        elif kind == "v_prediction":
            alpha = self.alpha_t[timestep]
            sigma = self.sigma_t[timestep]
            noise_pred = alpha * model_output + sigma * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverMultistepScheduler."
            )

        if not self.config.thresholding:
            return noise_pred

        # Thresholding acts on the x0 estimate, so go to x0, threshold, and come back to epsilon.
        alpha = self.alpha_t[timestep]
        sigma = self.sigma_t[timestep]
        denoised = (sample - sigma * noise_pred) / alpha
        denoised = self._threshold_sample(denoised)
        return (sample - alpha * denoised) / sigma

    # Any other algorithm type falls through and yields None.
    return None
""" lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None x_t = ( (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": assert noise is not None x_t = ( (alpha_t / alpha_s) * sample - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise ) return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update def multistep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the second-order multistep DPM-Solver. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": assert noise is not None 
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.FloatTensor],
    timestep_list: List[int],
    prev_timestep: int,
    sample: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    One step for the third-order multistep DPM-Solver.

    Only the `dpmsolver++` and `dpmsolver` algorithm types are handled here; no result is assigned for any
    other type.

    Args:
        model_output_list (`List[torch.FloatTensor]`):
            converted model outputs; the last three entries are used, the final one belonging to the current
            timestep.
        timestep_list (`List[int]`):
            discrete timesteps matching `model_output_list`; the last three entries are used.
        prev_timestep (`int`): the discrete timestep this update moves the sample to.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the sample tensor at `prev_timestep`.
    """
    target = prev_timestep
    cur, back1, back2 = timestep_list[-1], timestep_list[-2], timestep_list[-3]
    out_cur, out_back1, out_back2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]

    # Schedule quantities at the target and at the three history points.
    lam_target = self.lambda_t[target]
    lam_cur = self.lambda_t[cur]
    lam_back1 = self.lambda_t[back1]
    lam_back2 = self.lambda_t[back2]
    alpha_target = self.alpha_t[target]
    alpha_cur = self.alpha_t[cur]
    sigma_target = self.sigma_t[target]
    sigma_cur = self.sigma_t[cur]

    # Step in lambda for this update, and the two preceding steps expressed relative to it.
    h = lam_target - lam_cur
    h_back1 = lam_cur - lam_back1
    h_back2 = lam_back1 - lam_back2
    ratio1 = h_back1 / h
    ratio2 = h_back2 / h

    # Divided differences of the model outputs.
    diff0 = out_cur
    slope_recent = (1.0 / ratio1) * (out_cur - out_back1)
    slope_older = (1.0 / ratio2) * (out_back1 - out_back2)
    diff1 = slope_recent + (ratio1 / (ratio1 + ratio2)) * (slope_recent - slope_older)
    diff2 = (1.0 / (ratio1 + ratio2)) * (slope_recent - slope_older)

    algorithm = self.config.algorithm_type
    if algorithm == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        shrink = torch.exp(-h) - 1.0
        x_t = (sigma_target / sigma_cur) * sample
        x_t = x_t - (alpha_target * shrink) * diff0
        x_t = x_t + (alpha_target * (shrink / h + 1.0)) * diff1
        x_t = x_t - (alpha_target * ((shrink + h) / h**2 - 0.5)) * diff2
    elif algorithm == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        grow = torch.exp(h) - 1.0
        x_t = (alpha_target / alpha_cur) * sample
        x_t = x_t - (sigma_target * grow) * diff0
        x_t = x_t - (sigma_target * (grow / h - 1.0)) * diff1
        x_t = x_t - (sigma_target * ((grow - h) / h**2 - 0.5)) * diff2
    return x_t
When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = ( self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] ) lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) else: noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.dpm_solver_first_order_update( model_output, timestep, prev_timestep, sample, noise=noise ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) if self.lower_order_nums < self.config.solver_order: 
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix `noise` into `original_samples` at the noise level of the given `timesteps`.

    Each sample is scaled by `sqrt(alphas_cumprod[t])` and its noise by `sqrt(1 - alphas_cumprod[t])`; the two
    are summed and returned.
    """
    sample_rank = len(original_samples.shape)

    def _per_sample(coefficients):
        # Flatten to one value per entry, then append trailing axes until the rank matches the samples.
        coefficients = coefficients.flatten()
        for _ in range(sample_rank - len(coefficients.shape)):
            coefficients = coefficients.unsqueeze(-1)
        return coefficients

    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
    cumulative = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    selected = cumulative[timesteps.to(original_samples.device)]

    signal_scale = _per_sample(selected**0.5)
    noise_scale = _per_sample((1 - selected) ** 0.5)

    return signal_scale * original_samples + noise_scale * noise
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        """
        Build one `torchsde.BrownianTree` per seed.

        Args:
            x (`torch.Tensor`): tensor whose zeros-like copy is the default initial value `w0`.
            t0, t1: end points of the interval; they may be given in either order.
            seed (`int` or sequence of `int`, *optional*):
                a single seed builds one unbatched tree; a sequence must contain one seed per entry of
                `x.shape[0]` and builds one tree per seed. If `None`, a random seed is drawn.
            **kwargs: forwarded to `torchsde.BrownianTree`. A `w0` entry overrides the default initial value.

        Raises:
            ValueError: if `seed` is a sequence whose length differs from `x.shape[0]`.
        """
        t0, t1, self.sign = self.sort(t0, t1)
        # `w0` is handed to BrownianTree positionally below, so it must be removed from `kwargs`;
        # leaving it there would pass the same argument twice.
        w0 = kwargs.pop("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()

        # A seed without a length is a single seed. The probe is a real statement rather than part of an
        # `assert`, so the detection still works when assertions are disabled (`python -O`).
        try:
            seed_count = len(seed)
        except TypeError:
            seed = [seed]
            self.batched = False
        else:
            if seed_count != x.shape[0]:
                raise ValueError(
                    f"Expected one seed per batch item ({x.shape[0]}), but got {seed_count} seeds."
                )
            self.batched = True
            # Every tree in the batch handles a single item.
            w0 = w0[0]

        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return `(low, high, sign)` where `sign` is `1` if `a < b` and `-1` otherwise."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Evaluate every tree between `t0` and `t1`, restoring the sign implied by the argument order."""
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`, where `N` is
    `num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # The message names the actual keyword argument so the caller can find it.
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # Cap each beta so that 1 - beta never reaches zero.
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.00085, # sensible defaults beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas self.noise_sampler = None self.noise_sampler_seed = noise_sampler_seed # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) if len(self._index_counter) == 0: pos = 1 if len(indices) > 1 else 0 else: timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep pos = self._index_counter[timestep_int] return indices[pos].item() @property def init_noise_sigma(self): # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], ) -> torch.FloatTensor: """ Args: Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ step_index = self.index_for_timestep(timestep) sigma = self.sigmas[step_index] sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma sample = sample / ((sigma_input**2 + 1) ** 0.5) return sample def set_timesteps( self, num_inference_steps: int, device: Union[str, torch.device] = None, num_train_timesteps: Optional[int] = None, ): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() elif self.config.timestep_spacing == "leading": step_ratio = num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) timesteps = torch.from_numpy(timesteps) second_order_timesteps = torch.from_numpy(second_order_timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) timesteps[1::2] = second_order_timesteps if str(device).startswith("mps"): # mps does not support float64 self.timesteps = timesteps.to(device, dtype=torch.float32) else: self.timesteps = timesteps.to(device=device) # empty first order variables self.sample = None self.mid_point_sigma = None # for exp beta schedules, such as the one for `pipeline_shap_e.py` # we need an index counter self._index_counter = defaultdict(int) def _second_order_timesteps(self, sigmas, log_sigmas): def sigma_fn(_t): return np.exp(-_t) def t_fn(_sigma): return -np.log(_sigma) midpoint_ratio = 0.5 t = t_fn(sigmas) delta_time = np.diff(t) t_proposed = t[:-1] + delta_time * midpoint_ratio sig_proposed = sigma_fn(t_proposed) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed]) return timesteps # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(sigma) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = 
log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. (2022).""" sigma_min: float = in_sigmas[-1].item() sigma_max: float = in_sigmas[0].item() rho = 7.0 # 7.0 is the value used in the paper ramp = np.linspace(0, 1, self.num_inference_steps) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas @property def state_in_first_order(self): return self.sample is None def step( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: Union[float, torch.FloatTensor], sample: Union[torch.FloatTensor, np.ndarray], return_dict: bool = True, s_noise: float = 1.0, ) -> Union[SchedulerOutput, Tuple]: """ Args: Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model. timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain. sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process. return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True. s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is the sample tensor. """ step_index = self.index_for_timestep(timestep) # advance index counter by 1 timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep self._index_counter[timestep_int] += 1 # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) # Define functions to compute sigma and t from each other def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor: return _t.neg().exp() def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: return _sigma.log().neg() if self.state_in_first_order: sigma = self.sigmas[step_index] sigma_next = self.sigmas[step_index + 1] else: # 2nd order sigma = self.sigmas[step_index - 1] sigma_next = self.sigmas[step_index] # Set the midpoint and step size for the current step midpoint_ratio = 0.5 t, t_next = t_fn(sigma), t_fn(sigma_next) delta_time = t_next - t t_proposed = t + delta_time * midpoint_ratio # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) pred_original_sample = sample - sigma_input * model_output elif self.config.prediction_type == "v_prediction": sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( sample / (sigma_input**2 + 1) ) elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) if sigma_next == 0: derivative = (sample - pred_original_sample) / sigma dt = sigma_next - sigma prev_sample = sample + derivative * dt else: if self.state_in_first_order: t_next = t_proposed else: sample = self.sample sigma_from = sigma_fn(t) sigma_to = sigma_fn(t_next) sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 ancestral_t = t_fn(sigma_down) prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( t - ancestral_t ).expm1() * pred_original_sample prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up if self.state_in_first_order: # store for 2nd order step self.sample = sample self.mid_point_sigma = sigma_fn(t_next) else: # free for "first order mode" self.sample = None self.mid_point_sigma = None if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as 
original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_dpmsolver_singlestep.py ================================================ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` curve, t in [0, 1], into a beta schedule.

    `alpha_bar(t)` is the cumulative product of (1 - beta) up to time `t`; each beta is recovered from the ratio
    of `alpha_bar` at the two ends of its sub-interval and then capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of betas used by the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    capped = [
        min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(capped, dtype=torch.float32)
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v-prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. 
The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable this to use up all the function evaluations. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. 
DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def get_order_list(self, num_inference_steps: int) -> List[int]:
    """
    Computes the solver order at each time step.

    The orders cycle through `1..solver_order`. With `lower_order_final` enabled the tail of the list is
    filled with lower orders so that exactly `num_inference_steps` entries are produced; without it only
    complete cycles are emitted, so the list may be shorter than `num_inference_steps`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.

    Returns:
        `List[int]`: the solver order to use at each step.

    Raises:
        ValueError: if `self.config.solver_order` is not 1, 2 or 3.
    """
    steps = num_inference_steps
    order = self.config.solver_order

    # Fail with a clear message instead of an UnboundLocalError further down.
    if order not in (1, 2, 3):
        raise ValueError(f"`solver_order` must be 1, 2 or 3 for {self.__class__}, but got {order}.")

    # No steps means no orders (the order-3 tail below would otherwise still emit three entries).
    if steps <= 0:
        return []

    cycle = list(range(1, order + 1))  # [1], [1, 2] or [1, 2, 3]
    full_cycles, remainder = divmod(steps, order)

    if not self.config.lower_order_final:
        return cycle * full_cycles

    if order == 3 and remainder == 0:
        # finish a third-order run with [1, 2] + [1] rather than a final third-order step
        return cycle * (full_cycles - 1) + [1, 2] + [1]

    # remaining steps reuse the start of the cycle: [], [1] or [1, 2]
    return cycle * full_cycles + cycle[:remainder]
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order self.sample = None if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: logger.warn( "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." ) self.register_to_config(lower_order_final=True) self.order_list = self.get_order_list(num_inference_steps) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." 
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, height, width = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * height * width) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, height, width) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(sigma) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. 
def convert_model_output(
    self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
    """
    Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.

    DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
    discretize an integral of the data prediction model. So we need to first convert the model output to the
    corresponding type to match the algorithm.

    Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
    DPM-Solver++ for both noise prediction model and data prediction model.

    For `epsilon` models whose `variance_type` is `learned` or `learned_range`, the output also carries the
    predicted variance; only the leading `sample.shape[1]` channels (the mean) are kept.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the converted model output.

    Raises:
        ValueError: if `prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    # DPM-Solver and DPM-Solver++ only need the "mean" output. Both variance-predicting modes are
    # handled, and the mean is taken to have as many channels as `sample` rather than a fixed 3.
    predicts_variance = self.config.variance_type in ["learned", "learned_range"]

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type == "dpmsolver++":
        if self.config.prediction_type == "epsilon":
            if predicts_variance:
                model_output = model_output[:, : sample.shape[1]]
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred
    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type == "dpmsolver":
        if self.config.prediction_type == "epsilon":
            if predicts_variance:
                model_output = model_output[:, : sample.shape[1]]
            return model_output
        elif self.config.prediction_type == "sample":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            epsilon = (sample - alpha_t * model_output) / sigma_t
            return epsilon
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            epsilon = alpha_t * model_output + sigma_t * sample
            return epsilon
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )
timestep (`int`): current discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output return x_t def singlestep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the second-order singlestep DPM-Solver. It computes the solution at time `prev_timestep` from the time `timestep_list[-2]`. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1] sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1] h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m1, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s1) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s1) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s1) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s1) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) return x_t def singlestep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the third-order singlestep DPM-Solver. It computes the solution at time `prev_timestep` from the time `timestep_list[-3]`. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. 
sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2], ) alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2] sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m2 D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s2) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s2) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s2) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s2) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t def singlestep_dpm_solver_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], 
prev_timestep: int, sample: torch.FloatTensor, order: int, ) -> torch.FloatTensor: """ One step for the singlestep DPM-Solver. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. order (`int`): the solver order at this step. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ if order == 1: return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) elif order == 2: return self.singlestep_dpm_solver_second_order_update( model_output_list, timestep_list, prev_timestep, sample ) elif order == 3: return self.singlestep_dpm_solver_third_order_update( model_output_list, timestep_list, prev_timestep, sample ) else: raise ValueError(f"Order must be 1, 2, 3, got {order}") def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the singlestep DPM-Solver. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output order = self.order_list[step_index] # For img2img denoising might start with order>1 which is not possible # In this case make sure that the first two steps are both order=1 while self.model_outputs[-order] is None: order -= 1 # For single-step solvers, we use the initial value at each time with order = 1. if order == 1: self.sample = sample timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] prev_sample = self.singlestep_dpm_solver_update( self.model_outputs, timestep_list, prev_timestep, self.sample, order ) if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_euler_ancestral_discrete.py ================================================ # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with the betas used by the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear` or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample; `step` raises `NotImplementedError` for it) or
            `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
        timestep_spacing (`str`, default `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps; `set_timesteps` applies it only for the `"leading"` timestep
            spacing.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # sigma = sqrt((1 - alpha_bar) / alpha_bar), stored in reversed order with a trailing 0.0
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False

    @property
    def init_noise_sigma(self):
        """Standard deviation of the initial noise distribution."""
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the denoising model input by `1 / (sigma**2 + 1) ** 0.5` to match the Euler algorithm.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`):
                the current timestep in the diffusion chain; must be one of `self.timesteps`

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        # Remembered so that `step` can warn when the input was never scaled.
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

        Raises:
            ValueError: if `config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
        """
        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
                ::-1
            ].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer-valued timesteps by multiplying by ratio; they are rounded and stored as floats
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer-valued timesteps by stepping down by ratio; they are rounded and stored as floats
            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate the training sigmas at the chosen timesteps and append the final sigma of 0.0
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)
        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain; must be one of `self.timesteps`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, optional): Random number generator.
            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
            a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `timestep` is an integer or integer tensor, or if `config.prediction_type` is unknown.
            NotImplementedError: if `config.prediction_type` is `sample`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # Split the step to the next sigma into a deterministic part (sigma_down) and added noise (sigma_up)
        sigma_from = self.sigmas[step_index]
        sigma_to = self.sigmas[step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device
        noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

        prev_sample = prev_sample + noise * sigma_up

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add `noise` scaled by the sigma of each timestep to `original_samples`.

        Args:
            original_samples (`torch.FloatTensor`): the samples to add noise to.
            noise (`torch.FloatTensor`): the noise to add.
            timesteps (`torch.FloatTensor`): one timestep per sample; each must be one of `self.timesteps`.

        Returns:
            `torch.FloatTensor`: `original_samples + noise * sigma`.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps


# ================================================
# FILE: 4DoF/diffusers/schedulers/scheduling_euler_discrete.py
# ================================================
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete class EulerDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. 
beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. prediction_type (`str`, default `"epsilon"`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) interpolation_type (`str`, default `"linear"`, optional): interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of [`"linear"`, `"log_linear"`]. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas @property def init_noise_sigma(self): # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: 
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Writes `self.num_inference_steps`, `self.sigmas` (the inference sigmas followed by a trailing 0.0) and
    `self.timesteps`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: if `config.timestep_spacing` or `config.interpolation_type` is not a supported value.
    """
    self.num_inference_steps = num_inference_steps

    spacing = self.config.timestep_spacing
    train_steps = self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if spacing == "linspace":
        ascending = np.linspace(0, train_steps - 1, num_inference_steps, dtype=float)
        timesteps = ascending[::-1].copy()
    elif spacing == "leading":
        # integer stride between consecutive inference steps
        stride = train_steps // self.num_inference_steps
        multiples = (np.arange(0, num_inference_steps) * stride).round()
        timesteps = multiples[::-1].copy().astype(float)
        timesteps += self.config.steps_offset
    elif spacing == "trailing":
        # fractional stride, walked downwards from the last training step
        stride = train_steps / self.num_inference_steps
        timesteps = np.arange(train_steps, 0, -stride).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # Training sigmas in ascending timestep order; their logs are needed for the Karras remapping below.
    noise_to_signal = (1 - self.alphas_cumprod) / self.alphas_cumprod
    sigmas = np.array(noise_to_signal**0.5)
    log_sigmas = np.log(sigmas)

    interpolation = self.config.interpolation_type
    if interpolation == "linear":
        train_grid = np.arange(0, len(sigmas))
        sigmas = np.interp(timesteps, train_grid, sigmas)
    elif interpolation == "log_linear":
        log_low = np.log(sigmas[-1])
        log_high = np.log(sigmas[0])
        sigmas = torch.linspace(log_low, log_high, num_inference_steps + 1).exp()
    else:
        raise ValueError(
            f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
            " 'linear' or 'log_linear'"
        )

    if self.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

    # Append the terminal sigma of 0 so that `sigmas[i + 1]` exists for the last step.
    padded_sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(padded_sigmas).to(device=device)

    timestep_tensor = torch.from_numpy(timesteps)
    if str(device).startswith("mps"):
        # mps does not support float64
        timestep_tensor = timestep_tensor.to(device, dtype=torch.float32)
    else:
        timestep_tensor = timestep_tensor.to(device=device)
    self.timesteps = timestep_tensor
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain. Must be one of
            the values in `self.timesteps`; integer indices are rejected.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        s_churn (`float`): controls the churn factor `gamma`, computed as
            `min(s_churn / (len(self.sigmas) - 1), sqrt(2) - 1)`.
        s_tmin (`float`): lower bound of the sigma range in which `gamma` is applied.
        s_tmax (`float`): upper bound of the sigma range in which `gamma` is applied.
        s_noise (`float`): multiplier applied to the noise that is added when `gamma > 0`.
        generator (`torch.Generator`, optional): Random number generator.
        return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class

    Returns:
        [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `timestep` is an integer / integer tensor, or if `config.prediction_type` is not
            supported.
    """

    if (
        isinstance(timestep, int)
        or isinstance(timestep, torch.IntTensor)
        or isinstance(timestep, torch.LongTensor)
    ):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if not self.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    # Locate the current position in the schedule; `.item()` requires exactly one match.
    step_index = (self.timesteps == timestep).nonzero().item()
    sigma = self.sigmas[step_index]

    # Churn is only applied while sigma lies inside [s_tmin, s_tmax].
    gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

    # NOTE: the noise is drawn unconditionally (even when gamma == 0), so the generator state advances on
    # every call. Keep it this way to leave seeded runs reproducible with earlier results.
    noise = randn_tensor(
        model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
    )

    eps = noise * s_noise
    sigma_hat = sigma * (gamma + 1)

    if gamma > 0:
        # raise the noise level of the sample from sigma to sigma_hat
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    # NOTE: "original_sample" should not be an expected prediction_type but is left in for
    # backwards compatibility
    if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma_hat * model_output
    elif self.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    else:
        # The message lists every value accepted above (previously `sample` was missing).
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma_hat

    dt = self.sigmas[step_index + 1] - sigma_hat

    prev_sample = sample + derivative * dt

    if not return_dict:
        return (prev_sample,)

    return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
    k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, `squaredcos_cap_v2` or `exp`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf).
        clip_sample (`bool`, default `False`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
        timestep_spacing (`str`, default `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        clip_sample: Optional[bool] = False,
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
        elif beta_schedule == "exp":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        Timesteps may appear twice in the schedule (see `set_timesteps`), so the per-timestep counter
        `self._index_counter` selects which occurrence to return.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(self._index_counter) == 0:
            pos = 1 if len(indices) > 1 else 0
        else:
            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
            pos = self._index_counter[timestep_int]

        return indices[pos].item()

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. The sample is divided by `(sigma**2 + 1) ** 0.5` for the sigma at `timestep`.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        step_index = self.index_for_timestep(timestep)

        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, optional):
                number of training timesteps to spread the schedule over; falls back to
                `self.config.num_train_timesteps` when `None` (or 0).
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        # Every interior sigma is duplicated: each Heun update is performed as two `step` calls.
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        timesteps = torch.from_numpy(timesteps)
        # All timesteps except the first are duplicated to line up with the duplicated sigmas.
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = timesteps.to(device, dtype=torch.float32)
        else:
            self.timesteps = timesteps.to(device=device)

        # empty dt and derivative
        self.prev_derivative = None
        self.dt = None

        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
        # we need an index counter
        self._index_counter = defaultdict(int)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    @property
    def state_in_first_order(self):
        # `dt` is only set between the first-order and the second-order half of a Heun update.
        return self.dt is None

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: Union[float, torch.FloatTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Calls alternate between a first-order (Euler) prediction, which stores its derivative, step size and input
        sample on the scheduler, and a second-order correction, which averages the stored derivative with the new
        one and applies it to the stored sample.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
        """
        step_index = self.index_for_timestep(timestep)

        # advance index counter by 1
        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        self._index_counter[timestep_int] += 1

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
            sigma_next = self.sigmas[step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[step_index - 1]
            sigma_next = self.sigmas[step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_next
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_next
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            # The message lists every value accepted above (previously `sample` was missing).
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 2. 2nd order / Heun's method
            derivative = (sample - pred_original_sample) / sigma_next
            derivative = (self.prev_derivative + derivative) / 2

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Return `original_samples + noise * sigma`, with one sigma looked up per entry of `timesteps` and
        broadcast over the trailing dimensions of `original_samples`.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion
    [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296)

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        trained_betas (`np.ndarray` or `List[float]`, optional):
            betas used by `set_timesteps` instead of the built-in sine-squared schedule.
    """

    order = 1

    @register_to_config
    def __init__(
        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
    ):
        # set `betas`, `alphas`, `timesteps`
        self.set_timesteps(num_train_timesteps)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values
        self.ets = []

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Also rebuilds `self.betas` / `self.alphas` and clears the multistep history `self.ets`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to.
        """
        self.num_inference_steps = num_inference_steps
        steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
        steps = torch.cat([steps, torch.tensor([0.0])])

        if self.config.trained_betas is not None:
            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
        else:
            self.betas = torch.sin(steps * math.pi / 2) ** 2

        self.alphas = (1.0 - self.betas**2) ** 0.5

        # `betas` / `alphas` keep one more entry than `timesteps`; `step` reads index `i + 1` for the last step.
        timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
        self.timesteps = timesteps.to(device)

        self.ets = []

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current timestep in the diffusion chain; must be one of the values in
                `self.timesteps`.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `self.num_inference_steps` is `None`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep_index = (self.timesteps == timestep).nonzero().item()
        prev_timestep_index = timestep_index + 1

        ets = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
        self.ets.append(ets)
        # The formulas below read at most the four most recent entries. Dropping older ones keeps the
        # history from growing by one full-size tensor per step for the whole sampling run, without
        # changing which branch is taken (a length of 4 still selects the 4th-order formula).
        del self.ets[:-4]

        if len(self.ets) == 1:
            ets = self.ets[-1]
        elif len(self.ets) == 2:
            ets = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif len(self.ets) == 3:
            ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        """Combine `sample` and the multistep estimate `ets` into the sample at `prev_timestep_index`."""
        alpha = self.alphas[timestep_index]
        sigma = self.betas[timestep_index]

        next_alpha = self.alphas[prev_timestep_index]
        next_sigma = self.betas[prev_timestep_index]

        # guard against division by zero when alpha is 0
        pred = (sample - sigma * ets) / max(alpha, 1e-8)
        prev_sample = next_alpha * pred + ets * next_sigma

        return prev_sample

    def __len__(self):
        return self.config.num_train_timesteps
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at `max_beta`
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.00085, # sensible defaults beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) if len(self._index_counter) == 0: pos = 1 if len(indices) > 1 else 0 else: timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep pos = self._index_counter[timestep_int] return indices[pos].item() @property def init_noise_sigma(self): # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], ) -> torch.FloatTensor: """ Args: Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ step_index = self.index_for_timestep(timestep) if self.state_in_first_order: sigma = self.sigmas[step_index] else: sigma = self.sigmas_interpol[step_index - 1] sample = sample / ((sigma**2 + 1) ** 0.5) return sample def set_timesteps( self, num_inference_steps: int, device: Union[str, torch.device] = None, num_train_timesteps: Optional[int] = None, ): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() elif self.config.timestep_spacing == "leading": step_ratio = num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) # compute up and down sigmas sigmas_next = sigmas.roll(-1) sigmas_next[-1] = 0.0 sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5 sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5 sigmas_down[-1] = 0.0 # compute interpolated sigmas sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp() sigmas_interpol[-2:] = 0.0 # set sigmas self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]]) self.sigmas_interpol = torch.cat( [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]] ) self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]]) self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) if str(device).startswith("mps"): # mps does not support float64 timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) else: timesteps = torch.from_numpy(timesteps).to(device) timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) self.sample = None # for exp beta schedules, such as the one for `pipeline_shap_e.py` # we need an index counter self._index_counter = defaultdict(int) def sigma_to_t(self, sigma): # get log sigma log_sigma = sigma.log() # get distribution dists = log_sigma - self.log_sigmas[:, None] # get sigmas range low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = self.log_sigmas[low_idx] 
high = self.log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = w.clamp(0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.view(sigma.shape) return t @property def state_in_first_order(self): return self.sample is None def step( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: Union[float, torch.FloatTensor], sample: Union[torch.FloatTensor, np.ndarray], generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Args: Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ step_index = self.index_for_timestep(timestep) # advance index counter by 1 timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep self._index_counter[timestep_int] += 1 if self.state_in_first_order: sigma = self.sigmas[step_index] sigma_interpol = self.sigmas_interpol[step_index] sigma_up = self.sigmas_up[step_index] sigma_down = self.sigmas_down[step_index - 1] else: # 2nd order / KPDM2's method sigma = self.sigmas[step_index - 1] sigma_interpol = self.sigmas_interpol[step_index - 1] sigma_up = self.sigmas_up[step_index - 1] sigma_down = self.sigmas_down[step_index - 1] # currently only gamma=0 is supported. 
This usually works best anyways. # We can support gamma in the future but then need to scale the timestep before # passing it to the model which requires a change in API gamma = 0 sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now device = model_output.device noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol pred_original_sample = sample - sigma_input * model_output elif self.config.prediction_type == "v_prediction": sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( sample / (sigma_input**2 + 1) ) elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) if self.state_in_first_order: # 2. Convert to an ODE derivative for 1st order derivative = (sample - pred_original_sample) / sigma_hat # 3. delta timestep dt = sigma_interpol - sigma_hat # store for 2nd order step self.sample = sample self.dt = dt prev_sample = sample + derivative * dt else: # DPM-Solver-2 # 2. Convert to an ODE derivative for 2nd order derivative = (sample - pred_original_sample) / sigma_interpol # 3. 
delta timestep dt = sigma_down - sigma_hat sample = self.sample self.sample = None prev_sample = sample + derivative * dt prev_sample = prev_sample + noise * sigma_up if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_k_dpm_2_discrete.py ================================================ # Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For each step `i` the beta is `min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)` with
    `t1 = i / num_diffusion_timesteps` and `t2 = (i + 1) / num_diffusion_timesteps`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor holding the betas used by the scheduler to step the model
        outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        # t1 / t2 are the normalized start and end of the i-th interval
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
    https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188

    Scheduler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Every denoising step is split into two `step` calls: a first-order call that moves the sample to an
    interpolated sigma and remembers the incoming sample, followed by a second-order call that finishes the update
    from the remembered sample. No noise is added in either call.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear` or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf). `sample`
            (directly predicting the noisy sample) is recognised but `step` raises `NotImplementedError` for it.
        timestep_spacing (`str`, default `"linspace"`):
            The way the timesteps should be scaled, one of `linspace`, `leading` or `trailing`. Refer to Table 2. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more
            information.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. It is only applied when `timestep_spacing` is `leading`.
    """

    # names of every member of KarrasDiffusionSchedulers
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        """
        Build the beta / alpha tables and initialise the sigma and timestep schedule.

        `trained_betas`, when given, takes precedence over `beta_schedule`.

        Raises:
            NotImplementedError: if `beta_schedule` is not one of the supported names.
        """
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values: one inference step per training step, no device move
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)

    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        When several entries equal `timestep`, the match is picked from `self._index_counter`, which `step`
        increments once per call for the timestep it receives.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(self._index_counter) == 0:
            pos = 1 if len(indices) > 1 else 0
        else:
            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
            # `_index_counter` is a defaultdict(int): reading an unseen timestep yields 0 and stores that entry
            pos = self._index_counter[timestep_int]

        return indices[pos].item()

    @property
    def init_noise_sigma(self):
        """
        Standard deviation of the initial noise distribution.

        This is `sigmas.max()` for `linspace` and `trailing` spacing and `sqrt(sigmas.max()**2 + 1)` otherwise.
        """
        # standard deviation of the initial noise distribution
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Here the sample is divided by `sqrt(sigma**2 + 1)`.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        step_index = self.index_for_timestep(timestep)

        # first-order state reads the regular sigma table, second-order state the interpolated one
        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
        else:
            sigma = self.sigmas_interpol[step_index]

        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Besides `self.timesteps` this rebuilds `self.sigmas`, `self.sigmas_interpol` and `self.log_sigmas`, and
        resets the two-phase state (`self.sample`, `self._index_counter`).

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, optional):
                number of training timesteps to spread the inference steps over. Falls back to
                `self.config.num_train_timesteps` when `None` (or any other falsy value).

        Raises:
            ValueError: if `self.config.timestep_spacing` is not `linspace`, `leading` or `trailing`.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # training sigmas: sqrt((1 - alpha_bar) / alpha_bar); their logs are kept for `sigma_to_t`
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)

        # resample the sigmas at the (possibly fractional) inference timesteps and append a final 0
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)

        # interpolate sigmas: log-space midpoint between each sigma and the entry before it (`roll(1)`)
        sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()

        # every entry after the first is duplicated so each table has one slot per `step` call
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
        self.sigmas_interpol = torch.cat(
            [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
        )

        if str(device).startswith("mps"):
            # mps does not support float64
            timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            timesteps = torch.from_numpy(timesteps).to(device)

        # interpolate timesteps
        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
        interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()

        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

        # `None` means the next `step` call is a first-order one (see `state_in_first_order`)
        self.sample = None

        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
        # we need an index counter
        self._index_counter = defaultdict(int)

    def sigma_to_t(self, sigma):
        """
        Map sigma values to (fractional) training-timestep indices by linear interpolation in log-sigma space
        against `self.log_sigmas`.

        Args:
            sigma (`torch.Tensor`): sigma values; presumably 1-D, as passed by `set_timesteps` - confirm for other
                callers.

        Returns:
            `torch.Tensor`: interpolated indices with the same shape as `sigma`.
        """
        # get log sigma
        log_sigma = sigma.log()

        # get distribution
        dists = log_sigma - self.log_sigmas[:, None]

        # get sigmas range
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = self.log_sigmas[low_idx]
        high = self.log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.view(sigma.shape)
        return t

    @property
    def state_in_first_order(self):
        """`True` while no sample is stored, i.e. the next `step` call is the first-order half of a step."""
        return self.sample is None

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: Union[float, torch.FloatTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Calls alternate between a first-order state (stores `sample`, moves to the interpolated sigma) and a
        second-order state (resumes from the stored sample and moves to the next sigma).

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            NotImplementedError: if `prediction_type` is `sample`.
            ValueError: if `prediction_type` is none of `epsilon`, `v_prediction` or `sample`.
        """
        step_index = self.index_for_timestep(timestep)

        # advance index counter by 1
        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        self._index_counter[timestep_int] += 1

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
            sigma_interpol = self.sigmas_interpol[step_index + 1]
            sigma_next = self.sigmas[step_index + 1]
        else:
            # 2nd order / KDPM2's method
            sigma = self.sigmas[step_index - 1]
            sigma_interpol = self.sigmas_interpol[step_index]
            sigma_next = self.sigmas[step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_interpol - sigma_hat

            # store for 2nd order step
            self.sample = sample
        else:
            # DPM-Solver-2
            # 2. Convert to an ODE derivative for 2nd order
            derivative = (sample - pred_original_sample) / sigma_interpol

            # 3. delta timestep
            dt = sigma_next - sigma_hat

            # resume from the sample stored by the first-order call and return to the first-order state
            sample = self.sample
            self.sample = None

        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Return `original_samples + noise * sigma`, where `sigma` is looked up in `self.sigmas` for each entry of
        `timesteps` via `index_for_timestep`.

        Args:
            original_samples (`torch.FloatTensor`): samples to add noise to.
            noise (`torch.FloatTensor`): noise that gets scaled by sigma and added.
            timesteps (`torch.FloatTensor`): one timestep per sample; each is looked up in `self.timesteps`.

        Returns:
            `torch.FloatTensor`: the noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # append trailing singleton dims until sigma has as many dims as original_samples
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps


================================================
FILE: 4DoF/diffusers/schedulers/scheduling_karras_ve.py
================================================
# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import SchedulerMixin @dataclass class KarrasVeOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor derivative: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.02,
    sigma_max: float = 100,
    s_noise: float = 1.007,
    s_churn: float = 80,
    s_min: float = 0.05,
    s_max: float = 50,
):
    """
    Initialise the scheduler's runtime attributes.

    The arguments are not stored on `self` here; the other methods read them back through `self.config`
    (e.g. `self.config.sigma_max`), which is presumably populated by the `register_to_config` decorator.

    Args:
        sigma_min (`float`): minimum noise magnitude.
        sigma_max (`float`): maximum noise magnitude; also used as `init_noise_sigma`.
        s_noise (`float`): scale of the extra noise injected by `add_noise_to_input`.
        s_churn (`float`): overall amount of stochasticity.
        s_min (`float`): start of the sigma range in which noise is added.
        s_max (`float`): end of the sigma range in which noise is added.
    """
    # standard deviation of the initial noise distribution
    self.init_noise_sigma = sigma_max

    # setable values: all stay `None` until `set_timesteps` is called
    self.num_inference_steps: Optional[int] = None
    # integer tensor built with `torch.from_numpy` (the previous `np.IntTensor` annotation named a
    # type that does not exist in NumPy)
    self.timesteps: Optional[torch.Tensor] = None
    self.schedule: Optional[torch.FloatTensor] = None  # sigma(t_i)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

    `self.timesteps` becomes the descending integers `num_inference_steps - 1, ..., 0` and `self.schedule`
    holds, for each timestep `i` in that order,
    `sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1))`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            device the timesteps and schedule tensors are created on / moved to.
    """
    self.num_inference_steps = num_inference_steps
    timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
    self.timesteps = torch.from_numpy(timesteps).to(device)

    # With a single inference step the interpolation span `num_inference_steps - 1` is zero and the
    # exponent would be 0 / 0 = NaN; clamp the span so the lone step gets exponent 0 (i.e. sigma_max**2).
    span = max(num_inference_steps - 1, 1)
    schedule = [
        (
            self.config.sigma_max**2
            * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / span)
        )
        for i in self.timesteps
    ]
    self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
def step_correct(
    self,
    model_output: torch.FloatTensor,
    sigma_hat: float,
    sigma_prev: float,
    sample_hat: torch.FloatTensor,
    sample_prev: torch.FloatTensor,
    derivative: torch.FloatTensor,
    return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
    """
    Correct the predicted sample based on the output `model_output` of the network.

    A second derivative is evaluated at `sample_prev` / `sigma_prev`, averaged with the incoming
    `derivative`, and the step from `sample_hat` is redone with that average.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        sigma_hat (`float`): noise level the step starts from.
        sigma_prev (`float`): noise level the step lands on; used as a divisor.
        sample_hat (`torch.FloatTensor`): sample at `sigma_hat`.
        sample_prev (`torch.FloatTensor`): uncorrected sample at `sigma_prev`.
        derivative (`torch.FloatTensor`): derivative obtained from the uncorrected step.
        return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class

    Returns:
        [`KarrasVeOutput`] if `return_dict` is True, otherwise the tuple `(prev_sample, derivative)`.
        In both cases the returned derivative is the `derivative` argument, unchanged.
    """
    denoised = sample_prev + sigma_prev * model_output
    slope_at_prev = (sample_prev - denoised) / sigma_prev

    # average of the slope from the first step and the slope at its end point
    mean_slope = 0.5 * derivative + 0.5 * slope_at_prev
    corrected = sample_hat + (sigma_prev - sigma_hat) * mean_slope

    if return_dict:
        return KarrasVeOutput(prev_sample=corrected, derivative=derivative, pred_original_sample=denoised)
    return (corrected, derivative)
Args: prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. """ prev_sample: jnp.ndarray derivative: jnp.ndarray state: KarrasVeSchedulerState class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. Args: sigma_min (`float`): minimum noise magnitude sigma_max (`float`): maximum noise magnitude s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, 1.011]. 
def set_timesteps(
    self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> KarrasVeSchedulerState:
    """
    Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        state (`KarrasVeSchedulerState`):
            the `FlaxKarrasVeScheduler` state data class.
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        shape (`Tuple`):
            not used by this method.

    Returns:
        `KarrasVeSchedulerState`: copy of `state` (via `state.replace`) with `num_inference_steps`,
        `timesteps` (descending `num_inference_steps - 1, ..., 0`) and `schedule` filled in.
    """
    timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()

    # With a single inference step the interpolation span `num_inference_steps - 1` is zero and the
    # exponent would be 0 / 0 = NaN; clamp the span so the lone step gets exponent 0 (i.e. sigma_max**2).
    span = max(num_inference_steps - 1, 1)
    schedule = [
        (
            self.config.sigma_max**2
            * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / span)
        )
        for i in timesteps
    ]

    return state.replace(
        num_inference_steps=num_inference_steps,
        schedule=jnp.array(schedule, dtype=jnp.float32),
        timesteps=timesteps,
    )
def step(
    self,
    state: KarrasVeSchedulerState,
    model_output: jnp.ndarray,
    sigma_hat: float,
    sigma_prev: float,
    sample_hat: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class; passed through unchanged.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        sigma_hat (`float`): noise level of `sample_hat`; used as a divisor.
        sigma_prev (`float`): noise level to step to.
        sample_hat (`jnp.ndarray`): sample at noise level `sigma_hat`.
        return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

    Returns:
        [`FlaxKarrasVeOutput`] if `return_dict` is True, otherwise the tuple
        `(prev_sample, derivative, state)`.
    """
    denoised = sample_hat + sigma_hat * model_output
    slope = (sample_hat - denoised) / sigma_hat
    stepped = sample_hat + (sigma_prev - sigma_hat) * slope

    if return_dict:
        return FlaxKarrasVeOutput(prev_sample=stepped, derivative=slope, state=state)
    return (stepped, slope, state)
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete class LMSDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # message previously misspelled the parameter name as `alpha_tranform_type`
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta / alpha / sigma tables and initialise the sampling state.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): first beta of the generated schedule.
        beta_end (`float`): last beta of the generated schedule.
        beta_schedule (`str`): one of `linear`, `scaled_linear` or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray` or `List[float]`, optional): explicit betas; when given, `beta_schedule`,
            `beta_start` and `beta_end` are ignored.
        use_karras_sigmas (`bool`, optional): stored on `self.use_karras_sigmas` and consulted by
            `set_timesteps`.
        prediction_type (`str`), timestep_spacing (`str`), steps_offset (`int`): not used in this body; other
            methods read them through `self.config`.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of the supported names and `trained_betas` is None.
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        # message previously read "does is not implemented"
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # sigma = sqrt((1 - alpha_bar) / alpha_bar), reversed to run from high to low noise, with a final 0
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(sigmas)

    # setable values
    self.num_inference_steps = None
    self.use_karras_sigmas = use_karras_sigmas
    # populates `self.timesteps` (and reassigns `self.sigmas` / `self.derivatives`) for the full
    # training schedule; must run after `alphas_cumprod` and `use_karras_sigmas` are set
    self.set_timesteps(num_train_timesteps, None)
    self.derivatives = []
    self.is_scale_input_called = False
def get_lms_coefficient(self, order, t, current_order):
    """
    Compute a linear multistep coefficient.

    The coefficient is the integral, from `self.sigmas[t]` to `self.sigmas[t + 1]`, of the product over
    `k in range(order)`, `k != current_order`, of
    `(tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])`.

    Args:
        order: number of nodes `sigmas[t], sigmas[t - 1], ...` taking part in the product.
        t: index into `self.sigmas` of the current step.
        current_order: offset (0-based) of the node this coefficient belongs to.

    Returns:
        The value of the integral as computed by `scipy.integrate.quad` with `epsrel=1e-4`.
    """
    other_nodes = [k for k in range(order) if k != current_order]

    def basis_polynomial(tau):
        value = 1.0
        for k in other_nodes:
            value *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
        return value

    quad_result = integrate.quad(basis_polynomial, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)
    # first entry is the integral estimate; the rest is error information
    return quad_result[0]
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ ::-1 ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
# adapted from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Map sigma value(s) to fractional indices into `log_sigmas` by interpolating linearly in log-space.

    Args:
        sigma: NumPy scalar or array of sigma values (must expose `.shape`).
        log_sigmas: 1-D array of log-sigmas that defines the index grid.

    Returns:
        Array with the shape of `sigma` holding fractional indices; the interpolation weight is clipped
        to [0, 1], so results stay between the two neighbouring grid indices.
    """
    target = np.log(sigma)

    # one row per grid entry, one column per queried sigma
    gaps = target - log_sigmas[:, np.newaxis]

    # counting the entries that satisfy `gap >= 0` and taking the first position of the maximum count
    # yields the lower neighbour; it is capped so that `lower + 1` remains a valid index
    reached = np.cumsum((gaps >= 0), axis=0)
    lower = reached.argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    upper = lower + 1

    log_lo = log_sigmas[lower]
    log_hi = log_sigmas[upper]

    # relative position between the two neighbours
    weight = np.clip((log_lo - target) / (log_lo - log_hi), 0, 1)

    blended = (1 - weight) * lower + weight * upper
    return blended.reshape(sigma.shape)
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    order: int = 4,
    return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain; it must match
            exactly one entry of `self.timesteps`.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        order (`int`): maximum number of stored derivatives combined in the multistep update; the effective
            order is `min(step_index + 1, order)`.
        return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

    Returns:
        [`LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple` whose single element is
        the sample tensor.

    Raises:
        ValueError: if `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    if not self.is_scale_input_called:
        warnings.warn(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    # `.item()` requires exactly one matching entry in `self.timesteps`
    step_index = (self.timesteps == timestep).nonzero().item()
    sigma = self.sigmas[step_index]

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif self.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    else:
        # the message now names all three supported values (`sample` was missing)
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 2. Convert to an ODE derivative and keep at most `order` of the most recent ones
    derivative = (sample - pred_original_sample) / sigma
    self.derivatives.append(derivative)
    if len(self.derivatives) > order:
        self.derivatives.pop(0)

    # 3. Compute linear multistep coefficients
    order = min(step_index + 1, order)
    lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

    # 4. Compute previous sample based on the derivatives path (newest derivative pairs with coefficient 0)
    prev_sample = sample + sum(
        coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
    )

    if not return_dict:
        return (prev_sample,)

    return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def __len__(self):
    """Return `num_train_timesteps` as stored in the scheduler config."""
    train_steps = self.config.num_train_timesteps
    return train_steps
from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax.numpy as jnp from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, ) @flax.struct.dataclass class LMSDiscreteSchedulerState: common: CommonSchedulerState # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray sigmas: jnp.ndarray num_inference_steps: Optional[int] = None # running values derivatives: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray ): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) @dataclass class FlaxLMSSchedulerOutput(FlaxSchedulerOutput): state: LMSDiscreteSchedulerState class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): """ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. 
Choose from `linear` or `scaled_linear`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. """ _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: if common is None: common = CommonSchedulerState.create(self) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5 # standard deviation of the initial noise distribution init_noise_sigma = sigmas.max() return LMSDiscreteSchedulerState.create( common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas, ) def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: """ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. Args: state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. timestep (`int`): current discrete timestep in the diffusion chain. 
Returns: `jnp.ndarray`: scaled input sample """ (step_index,) = jnp.where(state.timesteps == timestep, size=1) step_index = step_index[0] sigma = state.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) return sample def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order): """ Compute a linear multistep coefficient. Args: order (TODO): t (TODO): current_order (TODO): """ def lms_derivative(tau): prod = 1.0 for k in range(order): if current_order == k: continue prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) return prod integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] return integrated_coeff def set_timesteps( self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) low_idx = jnp.floor(timesteps).astype(jnp.int32) high_idx = jnp.ceil(timesteps).astype(jnp.int32) frac = jnp.mod(timesteps, 1.0) sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5 sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)]) timesteps = timesteps.astype(jnp.int32) # initial running values derivatives = jnp.zeros((0,) + shape, dtype=self.dtype) return state.replace( timesteps=timesteps, sigmas=sigmas, num_inference_steps=num_inference_steps, derivatives=derivatives, ) def step( self, state: LMSDiscreteSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, order: int = 4, return_dict: bool = True, ) -> Union[FlaxLMSSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class Returns: [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) sigma = state.sigmas[timestep] # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) # 2. Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma state = state.replace(derivatives=jnp.append(state.derivatives, derivative)) if len(state.derivatives) > order: state = state.replace(derivatives=jnp.delete(state.derivatives, 0)) # 3. Compute linear multistep coefficients order = min(timestep + 1, order) lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] # 4. Compute previous sample based on the derivatives path prev_sample = sample + sum( coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) ) if not return_dict: return (prev_sample, state) return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state) def add_noise( self, state: LMSDiscreteSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: sigma = state.sigmas[timesteps].flatten() sigma = broadcast_to_shape_from_left(sigma, noise.shape) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_pndm.py ================================================ # Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class PNDMScheduler(SchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. set_alpha_to_one (`bool`, default `False`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, prediction_type: str = "epsilon", timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 # running values self.cur_model_output = 0 self.counter = 0 self.cur_sample = None self.ets = [] # setable values self.num_inference_steps = None self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.prk_timesteps = None self.plms_timesteps = None self.timesteps = None def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": self._timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() self._timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( np.int64 ) self._timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. 
When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 self.prk_timesteps = np.array([]) self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ ::-1 ].copy() else: prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order ) self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() self.plms_timesteps = self._timesteps[:-3][ ::-1 ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.ets = [] self.counter = 0 self.cur_model_output = 0 def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) else: return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) def step_prk( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 prev_timestep = timestep - diff_to_prev timestep = self.prk_timesteps[self.counter // 4 * 4] if self.counter % 4 == 0: self.cur_model_output += 1 / 6 * model_output self.ets.append(model_output) self.cur_sample = sample elif (self.counter - 1) % 4 == 0: self.cur_model_output += 1 / 3 * model_output elif (self.counter - 2) % 4 == 0: self.cur_model_output += 1 / 3 * model_output elif (self.counter - 3) % 4 == 0: model_output = self.cur_model_output + 1 / 6 * model_output self.cur_model_output = 0 # cur_sample should not be `None` cur_sample = self.cur_sample if self.cur_sample is not None else sample prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) self.counter += 1 if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def step_plms( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if not self.config.skip_prk_steps and len(self.ets) < 3: raise ValueError( f"{self.__class__} can only be run AFTER scheduler has been run " "in 'prk' mode for at least 12 iterations " "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " "for more information." ) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps if self.counter != 1: self.ets = self.ets[-3:] self.ets.append(model_output) else: prev_timestep = timestep timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps if len(self.ets) == 1 and self.counter == 0: model_output = model_output self.cur_sample = sample elif len(self.ets) == 1 and self.counter == 1: model_output = (model_output + self.ets[-1]) / 2 sample = self.cur_sample self.cur_sample = None elif len(self.ets) == 2: model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 elif len(self.ets) == 3: model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 else: model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) self.counter += 1 if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation # Notation ( -> # alpha_prod_t -> α_t # alpha_prod_t_prev -> α_(t−δ) # beta_prod_t -> (1 - α_t) # beta_prod_t_prev -> (1 - α_(t−δ)) # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev if self.config.prediction_type == "v_prediction": model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample elif self.config.prediction_type != "epsilon": raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`" ) # corresponds to (α_(t−δ) - α_t) divided by # denominator of x_t in formula (9) and plus 1 # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = # sqrt(α_(t−δ)) / sqrt(α_t)) sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) # corresponds to denominator of e_θ(x_t, t) in formula (9) model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( alpha_prod_t * beta_prod_t * alpha_prod_t_prev ) ** (0.5) # full formula (9) prev_sample = ( sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff ) return prev_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_pndm_flax.py ================================================ # Copyright 2023 Zhejiang University Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, ) @flax.struct.dataclass class PNDMSchedulerState: common: CommonSchedulerState final_alpha_cumprod: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None # running values cur_model_output: Optional[jnp.ndarray] = None counter: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None ets: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, final_alpha_cumprod: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): state: PNDMSchedulerState class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. 
For more details, see the original paper: https://arxiv.org/abs/2202.09778 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. set_alpha_to_one (`bool`, default `False`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. 
""" _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype pndm_order: int @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState: if common is None: common = CommonSchedulerState.create(self) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. final_alpha_cumprod = ( jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] ) # standard deviation of the initial noise distribution init_noise_sigma = jnp.array(1.0, dtype=self.dtype) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] return PNDMSchedulerState.create( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. 
num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. shape (`Tuple`): the shape of the samples to be generated. """ step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 prk_timesteps = jnp.array([], dtype=jnp.int32) plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] else: prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), self.pndm_order, ) prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] plms_timesteps = _timesteps[:-3][::-1] timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) # initial running values cur_model_output = jnp.zeros(shape, dtype=self.dtype) counter = jnp.int32(0) cur_sample = jnp.zeros(shape, dtype=self.dtype) ets = jnp.zeros((4,) + shape, dtype=self.dtype) return state.replace( timesteps=timesteps, num_inference_steps=num_inference_steps, prk_timesteps=prk_timesteps, plms_timesteps=plms_timesteps, cur_model_output=cur_model_output, counter=counter, cur_sample=cur_sample, ets=ets, ) def scale_model_input( self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
def step(
    self,
    state: PNDMSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
    """
    Advance the sample by one scheduler step.

    When `skip_prk_steps` is set in the config only `step_plms()` is run. Otherwise both `step_prk()` and
    `step_plms()` are evaluated on the incoming state, and `jax.lax.select` keeps the Runge-Kutta result while
    `state.counter` is smaller than the number of PRK timesteps and the linear multi-step result afterwards.

    Args:
        state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class

    Returns:
        [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple` `(prev_sample, state)`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `state.num_inference_steps` is `None`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.config.skip_prk_steps:
        prev_sample, state = self.step_plms(state, model_output, timestep, sample)
    else:
        # Both branches are computed from the same incoming state; the choice is made with
        # `jax.lax.select` on the counter rather than with Python control flow.
        rk_sample, rk_state = self.step_prk(state, model_output, timestep, sample)
        lms_sample, lms_state = self.step_plms(state, model_output, timestep, sample)

        use_prk = state.counter < len(state.prk_timesteps)
        prev_sample = jax.lax.select(use_prk, rk_sample, lms_sample)

        # Pick every running value from the branch that produced `prev_sample`.
        selected_fields = {
            field: jax.lax.select(use_prk, getattr(rk_state, field), getattr(lms_state, field))
            for field in ("cur_model_output", "ets", "cur_sample", "counter")
        }
        state = state.replace(**selected_fields)

    if return_dict:
        return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state)
    return (prev_sample, state)
""" if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) diff_to_prev = jnp.where( state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] model_output = jax.lax.select( (state.counter % 4) != 3, model_output, # remainder 0, 1, 2 state.cur_model_output + 1 / 6 * model_output, # remainder 3 ) state = state.replace( cur_model_output=jax.lax.select_n( state.counter % 4, state.cur_model_output + 1 / 6 * model_output, # remainder 0 state.cur_model_output + 1 / 3 * model_output, # remainder 1 state.cur_model_output + 1 / 3 * model_output, # remainder 2 jnp.zeros_like(state.cur_model_output), # remainder 3 ), ets=jax.lax.select( (state.counter % 4) == 0, state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # remainder 0 state.ets, # remainder 1, 2, 3 ), cur_sample=jax.lax.select( (state.counter % 4) == 0, sample, # remainder 0 state.cur_sample, # remainder 1, 2, 3 ), ) cur_sample = state.cur_sample prev_sample = self._get_prev_sample(state, cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) return (prev_sample, state) def step_plms( self, state: PNDMSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. 
return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) # Reference: # if state.counter != 1: # state.ets.append(model_output) # else: # prev_timestep = timestep # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) timestep = jnp.where( state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep ) # Reference: # if len(state.ets) == 1 and state.counter == 0: # model_output = model_output # state.cur_sample = sample # elif len(state.ets) == 1 and state.counter == 1: # model_output = (model_output + state.ets[-1]) / 2 # sample = state.cur_sample # state.cur_sample = None # elif len(state.ets) == 2: # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 # elif len(state.ets) == 3: # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 # else: # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) state = state.replace( ets=jax.lax.select( state.counter != 1, state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output), # counter != 1 state.ets, # counter 1 ), cur_sample=jax.lax.select( state.counter != 1, sample, # counter != 1 state.cur_sample, # counter 1 ), ) state = state.replace( cur_model_output=jax.lax.select_n( 
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
    """
    Compute x_(t−δ) from x_t following formula (9) of the PNDM paper (https://arxiv.org/pdf/2202.09778.pdf).

    Notation used below:
        alpha_prod_t       -> α_t
        alpha_prod_t_prev  -> α_(t−δ)
        beta_prod_t        -> (1 - α_t)
        beta_prod_t_prev   -> (1 - α_(t−δ))
        sample             -> x_t
        model_output       -> e_θ(x_t, t)

    Args:
        state (`PNDMSchedulerState`): scheduler state; `state.common.alphas_cumprod` and
            `state.final_alpha_cumprod` are read.
        sample: current sample x_t.
        timestep: index into `alphas_cumprod` for the current step.
        prev_timestep: index for the previous step; when negative, `state.final_alpha_cumprod` is used instead.
        model_output: model prediction, interpreted according to `self.config.prediction_type`.

    Returns:
        The sample at the previous timestep.

    Raises:
        ValueError: if `self.config.prediction_type` is neither `epsilon` nor `v_prediction`.
    """
    cumprod = state.common.alphas_cumprod
    alpha_prod_t = cumprod[timestep]
    alpha_prod_t_prev = jnp.where(prev_timestep >= 0, cumprod[prev_timestep], state.final_alpha_cumprod)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    prediction_type = self.config.prediction_type
    if prediction_type == "v_prediction":
        # convert the v-prediction into the noise estimate used by the update below
        model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    elif prediction_type != "epsilon":
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
        )

    # Coefficient of x_t: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) plus 1,
    # which simplifies to sqrt(α_(t−δ)) / sqrt(α_t).
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)

    # Denominator of the e_θ(x_t, t) term in formula (9).
    denom = alpha_prod_t * beta_prod_t_prev ** (0.5) + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** (0.5)

    noise_term = (alpha_prod_t_prev - alpha_prod_t) * model_output / denom
    return sample_coeff * sample - noise_term
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is not `cosine` or `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at `max_beta`
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. eta (`float`): The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and 1.0 is DDPM scheduler respectively. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. 
""" order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", eta: float = 0.0, trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) elif beta_schedule == "sigmoid": # GeoDiff sigmoid schedule betas = torch.linspace(-6, 6, num_train_timesteps) self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) self.final_alpha_cumprod = torch.tensor(1.0) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.eta = eta def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
def set_timesteps(
    self,
    num_inference_steps: int,
    jump_length: int = 10,
    jump_n_sample: int = 10,
    device: Union[str, torch.device] = None,
):
    """
    Build the RePaint timestep sequence, including the re-noising jumps.

    The sequence counts down from `num_inference_steps - 1` to 0. Every `jump_length`-th step below
    `num_inference_steps - jump_length` is a jump point: each time the countdown reaches it, and for at most
    `jump_n_sample - 1` times, the sequence climbs back up by `jump_length` steps before the countdown resumes.
    The result is scaled by `num_train_timesteps // num_inference_steps` and stored in `self.timesteps`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples; capped at
            `self.config.num_train_timesteps`.
        jump_length (`int`, default `10`): number of steps taken forward in time at each jump.
        jump_n_sample (`int`, default `10`): number of times each jump point is visited.
        device (`str` or `torch.device`, optional): device the timesteps tensor is moved to.
    """
    num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
    self.num_inference_steps = num_inference_steps

    # remaining number of jumps for every jump point
    remaining_jumps = dict.fromkeys(
        range(0, num_inference_steps - jump_length, jump_length),
        jump_n_sample - 1,
    )

    schedule = []
    current = num_inference_steps
    while current >= 1:
        current -= 1
        schedule.append(current)

        if remaining_jumps.get(current, 0) > 0:
            remaining_jumps[current] -= 1
            # walk `jump_length` steps back up (re-noise), then keep descending
            schedule.extend(range(current + 1, current + jump_length + 1))
            current += jump_length

    scale = self.config.num_train_timesteps // self.num_inference_steps
    scaled_schedule = np.array(schedule) * scale
    self.timesteps = torch.from_numpy(scaled_schedule).to(device)
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    original_image: torch.FloatTensor,
    mask: torch.FloatTensor,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[RePaintSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Follows RePaint Algorithm 1 (https://arxiv.org/pdf/2201.09865.pdf), but obtains x_{t-1} for the unknown region
    with formula (12) of the DDIM paper (https://arxiv.org/pdf/2010.02502.pdf) instead of the DDPM sampling formula.
    The same noise tensor is used for the stochastic term of the unknown region and for noising the known region.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        original_image (`torch.FloatTensor`):
            the original image to inpaint on.
        mask (`torch.FloatTensor`):
            the mask where 0.0 values define which part of the original image to inpaint (change).
        generator (`torch.Generator`, *optional*): random number generator.
        return_dict (`bool`): option for returning tuple rather than RePaintSchedulerOutput class

    Returns:
        [`RePaintSchedulerOutput`] or `tuple`: [`RePaintSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple` `(prev_sample, pred_original_sample)`. When returning a tuple, the first element is the sample
        tensor.
    """
    stride = self.config.num_train_timesteps // self.num_inference_steps
    prev_timestep = timestep - stride

    # cumulative alpha products at the current and at the previous timestep
    alpha_prod_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t

    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf, optionally clipped to [-1, 1]
    pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
    if self.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

    # noise shared by the stochastic term below and by the known-region noising
    noise = randn_tensor(
        model_output.shape,
        generator=generator,
        device=model_output.device,
        dtype=model_output.dtype,
    )
    std_dev_t = self.eta * self._get_variance(timestep) ** 0.5
    variance = std_dev_t * noise if (timestep > 0 and self.eta > 0) else 0

    # "direction pointing to x_t" of DDIM formula (12)
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output

    # x_{t-1} for the region being inpainted (DDIM formula (12))
    prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance

    # x_{t-1} for the known region: the original image noised to the previous timestep (Algorithm 1, line 5)
    prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise

    # blend both regions with the mask (Algorithm 1, line 8)
    pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part

    if return_dict:
        return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
    return (
        pred_prev_sample,
        pred_original_sample,
    )
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch import math from dataclasses import dataclass from typing import Optional, Tuple, Union import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import SchedulerMixin, SchedulerOutput @dataclass class SdeVeOutput(BaseOutput): """ Output class for the ScoreSdeVeScheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. """ prev_sample: torch.FloatTensor prev_sample_mean: torch.FloatTensor class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. For more information, see the original paper: https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. snr (`float`): coefficient weighting the step from the model_output sample (from the network) to the random noise. sigma_min (`float`): initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the distribution of the data. 
def set_sigmas(
    self,
    num_inference_steps: int,
    sigma_min: Optional[float] = None,
    sigma_max: Optional[float] = None,
    sampling_eps: Optional[float] = None,
):
    """
    Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

    The sigmas control the weight of the `drift` and `diffusion` components of sample update.

    Two attributes are written:
        - `self.discrete_sigmas`: `num_inference_steps` values spaced geometrically from `sigma_min` to `sigma_max`.
        - `self.sigmas`: `sigma_min * (sigma_max / sigma_min) ** t` for every `t` in `self.timesteps`.

    If `self.timesteps` has not been set yet, `set_timesteps(num_inference_steps, sampling_eps)` is called first.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        sigma_min (`float`, optional):
            initial noise scale value (overrides value given at Scheduler instantiation).
        sigma_max (`float`, optional):
            final noise scale value (overrides value given at Scheduler instantiation).
        sampling_eps (`float`, optional):
            final timestep value (overrides value given at Scheduler instantiation).
    """
    sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
    sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
    sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

    if self.timesteps is None:
        self.set_timesteps(num_inference_steps, sampling_eps)

    # NOTE: an earlier assignment of `self.sigmas` (using `self.timesteps / sampling_eps` as the exponent) was
    # removed here: it was overwritten below without ever being read.
    self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
    self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.timesteps is None: raise ValueError( "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" ) timestep = timestep * torch.ones( sample.shape[0], device=sample.device ) # torch.repeat_interleave(timestep, sample.shape[0]) timesteps = (timestep * (len(self.timesteps) - 1)).long() # mps requires indices to be in the same device, so we use cpu as is the default with cuda timesteps = timesteps.to(self.discrete_sigmas.device) sigma = self.discrete_sigmas[timesteps].to(sample.device) adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) drift = torch.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods diffusion = diffusion.flatten() while len(diffusion.shape) < len(sample.shape): diffusion = diffusion.unsqueeze(-1) drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of noise = randn_tensor( sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype ) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? 
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g if not return_dict: return (prev_sample, prev_sample_mean) return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) def step_correct( self, model_output: torch.FloatTensor, sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.timesteps is None: raise ValueError( "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" ) # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
of z" # sample noise for correction noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator).to(sample.device) # compute step size from the model_output, the noise, and the snr grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 step_size = step_size * torch.ones(sample.shape[0]).to(sample.device) # self.repeat_scalar(step_size, sample.shape[0]) # compute corrected sample: model_output term and noise term step_size = step_size.flatten() while len(step_size.shape) < len(sample.shape): step_size = step_size.unsqueeze(-1) prev_sample_mean = sample + step_size * model_output prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples timesteps = timesteps.to(original_samples.device) sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps] noise = ( noise * sigmas[:, None, None, None] if noise is not None else torch.randn_like(original_samples) * sigmas[:, None, None, None] ) noisy_samples = noise + original_samples return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 4DoF/diffusers/schedulers/scheduling_sde_ve_flax.py ================================================ # Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@flax.struct.dataclass
class ScoreSdeVeSchedulerState:
    """Arrays computed by `FlaxScoreSdeVeScheduler`; updated through `state.replace(...)`."""

    # setable values
    timesteps: Optional[jnp.ndarray] = None
    discrete_sigmas: Optional[jnp.ndarray] = None
    sigmas: Optional[jnp.ndarray] = None

    @classmethod
    def create(cls):
        """Return a state with every field left at `None`."""
        return cls()


@dataclass
class FlaxSdeVeOutput(FlaxSchedulerOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        state (`ScoreSdeVeSchedulerState`):
        prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
            Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
    """

    state: ScoreSdeVeSchedulerState
    prev_sample: jnp.ndarray
    prev_sample_mean: Optional[jnp.ndarray] = None


class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise.
        sigma_min (`float`):
            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
            distribution of the data.
        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample.
    """

    @property
    def has_state(self):
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        pass

    def create_state(self):
        """Build a fresh state whose timesteps and sigmas are computed from the config values."""
        state = ScoreSdeVeSchedulerState.create()
        return self.set_sigmas(
            state,
            self.config.num_train_timesteps,
            self.config.sigma_min,
            self.config.sigma_max,
            self.config.sampling_eps,
        )

    def set_timesteps(
        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`): accepted but not used by this method.
            sampling_eps (`float`, optional):
                final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `timesteps` running linearly from 1 to
            `sampling_eps`.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
        return state.replace(timesteps=timesteps)

    def set_sigmas(
        self,
        state: ScoreSdeVeSchedulerState,
        num_inference_steps: int,
        sigma_min: float = None,
        sigma_max: float = None,
        sampling_eps: float = None,
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                initial noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional):
                final noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional):
                final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `discrete_sigmas` and `sigmas` filled in (and
            `timesteps`, when they were not set yet).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if state.timesteps is None:
            # `sampling_eps` must go by keyword: the third positional parameter of `set_timesteps` is `shape`.
            state = self.set_timesteps(state, num_inference_steps, sampling_eps=sampling_eps)

        discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
        # one vectorised power instead of a Python loop over every timestep
        sigmas = sigma_min * (sigma_max / sigma_min) ** state.timesteps

        return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)

    def get_adjacent_sigma(self, state, timesteps, t):
        """Return `state.discrete_sigmas[timesteps - 1]`, with zero wherever `timesteps == 0`."""
        return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])

    def step_pred(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current timestep in the diffusion chain; it is multiplied by
                `len(state.timesteps) - 1` and truncated to index `state.discrete_sigmas`.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jnp.ndarray`): a single JAX PRNG key; a sub-key derived from it is used to draw the noise.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise the tuple
            `(prev_sample, prev_sample_mean, state)`.

        Raises:
            ValueError: if `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep = timestep * jnp.ones(
            sample.shape[0],
        )
        # jnp arrays have no `.long()`; cast to an integer dtype so the result can be used as an index
        timesteps = (timestep * (len(state.timesteps) - 1)).astype(jnp.int32)

        sigma = state.discrete_sigmas[timesteps]
        adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
        drift = jnp.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term
        # `random.split(..., num=1)` returns a batch of one key; take that key, `random.normal` wants a single key
        key = random.split(key, num=1)[0]
        noise = random.normal(key=key, shape=sample.shape)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)

    def step_correct(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key (`jnp.ndarray`): a single JAX PRNG key; a sub-key derived from it is used to draw the noise.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise the tuple
            `(prev_sample, state)`.

        Raises:
            ValueError: if `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        # `random.split(..., num=1)` returns a batch of one key; take that key, `random.normal` wants a single key
        key = random.split(key, num=1)[0]
        noise = random.normal(key=key, shape=sample.shape)

        # compute step size from the model_output, the noise, and the snr
        # (both norms are taken over the whole array, not per batch element)
        grad_norm = jnp.linalg.norm(model_output)
        noise_norm = jnp.linalg.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * jnp.ones(sample.shape[0])

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)

    def __len__(self):
        return self.config.num_train_timesteps
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance preserving stochastic differential equation (SDE) scheduler.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    UNDER CONSTRUCTION
    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
        # nothing is precomputed; `timesteps` is filled in by `set_timesteps`
        self.sigmas = self.discrete_sigmas = self.timesteps = None

    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
        """Store `num_inference_steps` timesteps running linearly from 1 to `config.sampling_eps` on `device`."""
        final_t = self.config.sampling_eps
        self.timesteps = torch.linspace(1, final_t, num_inference_steps, device=device)

    def step_pred(self, score, x, t, generator=None):
        """
        Take one reverse-SDE step from `x` at time `t`.

        Args:
            score: model output; it is divided by the standard deviation at `t` and negated before use.
            x: current sample.
            t: tensor of current times, flattened to one value per batch element.
            generator: random number generator forwarded to `randn_tensor`.

        Returns:
            The tuple `(x, x_mean)`: the noised next sample and its noise-free mean.

        Raises:
            ValueError: if `self.timesteps` has not been set.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        beta_min = self.config.beta_min
        beta_span = self.config.beta_max - self.config.beta_min

        # TODO(Patrick) better comments + non-PyTorch
        # postprocess model score
        log_mean_coeff = -0.25 * t**2 * beta_span - 0.5 * t * beta_min
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)).flatten()
        # append singleton axes so `std` broadcasts against `score`
        for _ in range(len(score.shape) - len(std.shape)):
            std = std.unsqueeze(-1)
        score = -score / std

        # compute
        dt = -1.0 / len(self.timesteps)

        beta_t = (beta_min + t * beta_span).flatten()
        for _ in range(len(x.shape) - len(beta_t.shape)):
            beta_t = beta_t.unsqueeze(-1)

        diffusion = torch.sqrt(beta_t)
        drift = -0.5 * beta_t * x - diffusion**2 * score
        x_mean = x + drift * dt

        # add noise
        noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
        noisy_x = x_mean + diffusion * math.sqrt(-dt) * noise

        return noisy_x, x_mean

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar (restyled here)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of betas used by the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar_fn = cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar_fn = exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar((i + 1) / total) / alpha_bar(i / total), capped at max_beta
    betas = [min(1 - alpha_bar_fn((i + 1) / total) / alpha_bar_fn(i / total), max_beta) for i in range(total)]
    return torch.tensor(betas, dtype=torch.float32)


class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
    """
    NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This
    scheduler will be removed and replaced with DDPM.

    This is a modified DDPM Scheduler specifically for the karlo unCLIP model.

    This scheduler has some minor variations in how it calculates the learned range variance and dynamically
    re-calculates betas based off the timesteps it is skipping.

    The scheduler also uses a slightly different step ratio when computing timesteps to use for inference.

    See [`~DDPMScheduler`] for more information on DDPM scheduling

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log`
            or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical
            stability.
        clip_sample_range (`float`, default `1.0`):
            The range to clip the sample between. See `clip_sample`.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process)
            or `sample` (directly predicting the noisy sample`)
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        variance_type: str = "fixed_small_log",
        clip_sample: bool = True,
        clip_sample_range: Optional[float] = 1.0,
        prediction_type: str = "epsilon",
        beta_schedule: str = "squaredcos_cap_v2",
    ):
        if beta_schedule != "squaredcos_cap_v2":
            raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'")

        betas = betas_for_alpha_bar(num_train_timesteps)
        self.betas = betas
        self.alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        descending_steps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.timesteps = torch.from_numpy(descending_steps)

        self.variance_type = variance_type

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler returns the input unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: the same `sample` that was passed in
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The
        different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy
        of the results.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional): device the timesteps tensor is moved to.
        """
        self.num_inference_steps = num_inference_steps
        step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1)
        ascending = (np.arange(0, num_inference_steps) * step_ratio).round()
        timesteps = ascending[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)

    def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None):
        """
        Compute the variance term used when stepping from `t` to `prev_timestep`.

        Args:
            t: current timestep index.
            prev_timestep: timestep being stepped to; defaults to `t - 1`.
            predicted_variance: model-predicted interpolation value, read only for `learned_range`.
            variance_type: overrides `self.config.variance_type` when given.

        Returns:
            For `fixed_small_log`, `exp(0.5 * log(clamp(variance, min=1e-20)))`; for `learned_range`, an
            interpolation between the logs of the posterior variance and of beta; otherwise the raw variance.
        """
        if prev_timestep is None:
            prev_timestep = t - 1

        alpha_prod_t = self.alphas_cumprod[t]
        if prev_timestep >= 0:
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
        else:
            alpha_prod_t_prev = self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # use the pre-computed beta for adjacent steps, otherwise derive it from the cumulative products
        beta = self.betas[t] if prev_timestep == t - 1 else 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = beta_prod_t_prev / beta_prod_t * beta

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small_log":
            log_variance = torch.log(torch.clamp(variance, min=1e-20))
            return torch.exp(0.5 * log_variance)

        if variance_type == "learned_range":
            # NOTE difference with DDPM scheduler
            min_log = variance.log()
            max_log = beta.log()

            frac = (predicted_variance + 1) / 2
            return frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        prev_timestep: Optional[int] = None,
        generator=None,
        return_dict: bool = True,
    ) -> Union[UnCLIPSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
                Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: for an unsupported `prediction_type`, or (when `t > 0`) an unsupported `variance_type`.
        """
        t = timestep

        # a model output with twice the sample's channels carries the predicted variance in its second half
        doubled_channels = model_output.shape[1] == sample.shape[1] * 2
        if doubled_channels and self.variance_type == "learned_range":
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        if prev_timestep is None:
            prev_timestep = t - 1

        alpha_prod_t = self.alphas_cumprod[t]
        if prev_timestep >= 0:
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
        else:
            alpha_prod_t_prev = self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        if prev_timestep == t - 1:
            beta, alpha = self.betas[t], self.alphas[t]
        else:
            beta = 1 - alpha_prod_t / alpha_prod_t_prev
            alpha = 1 - beta

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`"
                " for the UnCLIPScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            clip_range = self.config.clip_sample_range
            pred_original_sample = torch.clamp(pred_original_sample, -clip_range, clip_range)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
        current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        if t > 0:
            variance_noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
            )

            variance = self._get_variance(
                t,
                predicted_variance=predicted_variance,
                prev_timestep=prev_timestep,
            )

            # `fixed_small_log` is used as returned; `learned_range` comes back as a log-variance
            if self.variance_type == "learned_range":
                variance = (0.5 * variance).exp()
            elif self.variance_type != "fixed_small_log":
                raise ValueError(
                    f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`"
                    " for the UnCLIPScheduler."
                )

            variance = variance * variance_noise
        else:
            variance = 0

        pred_prev_sample = pred_prev_sample + variance

        if return_dict:
            return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
        return (pred_prev_sample,)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise (restyled here)
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Mix `original_samples` with `noise` at the noise level of `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples.
            noise (`torch.FloatTensor`): noise to blend in.
            timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

        Returns:
            `torch.FloatTensor`: `sqrt(alpha_prod) * original_samples + sqrt(1 - alpha_prod) * noise`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)
        target_rank = len(original_samples.shape)

        def per_sample(coefficient):
            # flatten to one value per entry, then append singleton axes up to the samples' rank
            coefficient = coefficient.flatten()
            for _ in range(target_rank - len(coefficient.shape)):
                coefficient = coefficient.unsqueeze(-1)
            return coefficient

        sqrt_alpha_prod = per_sample(alphas_cumprod[timesteps] ** 0.5)
        sqrt_one_minus_alpha_prod = per_sample((1 - alphas_cumprod[timesteps]) ** 0.5)

        return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the cosine alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`torch.Tensor`): float32 tensor of betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar((i + 1) / total) / alpha_bar(i / total), capped at max_beta
    capped = (min(1 - alpha_bar((i + 1) / total) / alpha_bar(i / total), max_beta) for i in range(total))
    return torch.tensor(list(capped), dtype=torch.float32)
Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of accuracy is `solver_order + 1` due to the UniC. We recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the clean sample x0) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf). thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. predict_x0 (`bool`, default `True`): whether to use the updating algorithm on the predicted x0.
See https://arxiv.org/abs/2211.01095 for details. solver_type (`str`, default `bh2`): the solver type of UniPC. We recommend using `bh1` for unconditional sampling when steps < 10, and `bh2` otherwise. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. disable_corrector (`list`, default `[]`): decide which step to disable the corrector. For large guidance scale, the misalignment between the `epsilon_theta(x_t, c)` and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated by disabling the corrector at the first few steps (e.g., disable_corrector=[0]). solver_p (`SchedulerMixin`, default `None`): can be any other scheduler. If specified, the algorithm will become solver_p + UniC. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion.
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, predict_x0: bool = True, solver_type: str = "bh2", lower_order_final: bool = True, disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps`, this stores `self.sigmas`, sets `self.num_inference_steps` to the number of
    *unique* timesteps that remain, and resets the multistep state (`model_outputs`, `lower_order_nums`,
    `last_sample`). When `self.solver_p` is set, its `set_timesteps` is called with the same step count.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: if `config.timestep_spacing` is not one of 'linspace', 'leading' or 'trailing'.
    """
    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
            .round()[::-1][:-1]
            .copy()
            .astype(np.int64)
        )
    elif self.config.timestep_spacing == "leading":
        step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1)
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = self.config.num_train_timesteps / num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # sigma = sqrt((1 - alpha_bar) / alpha_bar), indexed by training timestep (index 0 is the least noisy).
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if self.config.use_karras_sigmas:
        log_sigmas = np.log(sigmas)
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        # the Karras sigmas follow the order of `in_sigmas` (low noise first), sampling walks the other way
        timesteps = np.flip(timesteps).copy().astype(np.int64)

    self.sigmas = torch.from_numpy(sigmas)

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps.
    _, unique_indices = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(unique_indices)]

    self.timesteps = torch.from_numpy(timesteps).to(device)

    self.num_inference_steps = len(timesteps)

    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0
    self.last_sample = None
    if self.solver_p:
        self.solver_p.set_timesteps(self.num_inference_steps, device=device)


def _sigma_to_t(self, sigma, log_sigmas):
    """Map a noise level to a fractional training timestep by linear interpolation in log-sigma space."""
    sigma = np.asarray(sigma)
    log_sigma = np.log(sigma)

    # signed distance of the query to every entry of the training schedule
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # last schedule entry that is <= the query, clipped so that `low_idx + 1` stays in range
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1

    low = log_sigmas[low_idx]
    high = log_sigmas[high_idx]

    # interpolation weight between the two bracketing entries; clipping handles out-of-range queries
    w = np.clip((low - log_sigma) / (low - high), 0, 1)

    t = (1 - w) * low_idx + w * high_idx
    return t.reshape(sigma.shape)


def _convert_to_karras(self, in_sigmas, num_inference_steps):
    """
    Resample `in_sigmas` onto the noise schedule of Karras et al. (2022).

    The first and last entries of `in_sigmas` are kept as endpoints and `num_inference_steps` values are spaced
    linearly in sigma ** (1 / rho) between them, so the result has the same ordering as `in_sigmas`.
    """
    rho = 7.0  # 7.0 is the value used in the paper
    first_inv_rho = float(in_sigmas[0]) ** (1 / rho)
    last_inv_rho = float(in_sigmas[-1]) ** (1 / rho)
    ramp = np.linspace(0, 1, num_inference_steps)
    return (first_inv_rho + ramp * (last_inv_rho - first_inv_rho)) ** rho
def convert_model_output(
    self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
    r"""
    Translate the raw network output into the quantity the UniPC update formulas work with.

    With `self.predict_x0` set, the result is a prediction of x0 (optionally passed through
    `self._threshold_sample` when `config.thresholding` is on); otherwise it is a prediction of the noise.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain; used to index `alpha_t` / `sigma_t`.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the converted model output.

    Raises:
        ValueError: if `config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    kind = self.config.prediction_type
    if kind not in ("epsilon", "sample", "v_prediction"):
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction` for the UniPCMultistepScheduler."
        )

    if self.predict_x0:
        if kind == "sample":
            # the network already predicts x0
            x0_pred = model_output
        else:
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            if kind == "epsilon":
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            else:  # v_prediction
                x0_pred = alpha_t * sample - sigma_t * model_output
        return self._threshold_sample(x0_pred) if self.config.thresholding else x0_pred

    if kind == "epsilon":
        # the network already predicts the noise
        return model_output

    alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
    if kind == "sample":
        return (sample - alpha_t * model_output) / sigma_t
    # v_prediction
    return alpha_t * model_output + sigma_t * sample
""" timestep_list = self.timestep_list model_output_list = self.model_outputs s0, t = self.timestep_list[-1], prev_timestep m0 = model_output_list[-1] x = sample if self.solver_p: x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s0 device = sample.device rks = [] D1s = [] for i in range(1, order): si = timestep_list[-(i + 1)] mi = model_output_list[-(i + 1)] lambda_si = self.lambda_t[si] rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) R = [] b = [] hh = -h if self.predict_x0 else h h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 h_phi_k = h_phi_1 / hh - 1 factorial_i = 1 if self.config.solver_type == "bh1": B_h = hh elif self.config.solver_type == "bh2": B_h = torch.expm1(hh) else: raise NotImplementedError() for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) b = torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) # for order 2, we use a simplified version if order == 2: rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) else: rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) else: D1s = None if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - alpha_t * B_h * pred_res else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - sigma_t * B_h * pred_res x_t = x_t.to(x.dtype) return x_t def multistep_uni_c_bh_update( self, this_model_output: torch.FloatTensor, this_timestep: int, 
last_sample: torch.FloatTensor, this_sample: torch.FloatTensor, order: int, ) -> torch.FloatTensor: """ One step for the UniC (B(h) version). Args: this_model_output (`torch.FloatTensor`): the model outputs at `x_t` this_timestep (`int`): the current timestep `t` last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}` this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}` order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy should be order + 1 Returns: `torch.FloatTensor`: the corrected sample tensor at the current timestep. """ timestep_list = self.timestep_list model_output_list = self.model_outputs s0, t = timestep_list[-1], this_timestep m0 = model_output_list[-1] x = last_sample x_t = this_sample model_t = this_model_output lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s0 device = this_sample.device rks = [] D1s = [] for i in range(1, order): si = timestep_list[-(i + 1)] mi = model_output_list[-(i + 1)] lambda_si = self.lambda_t[si] rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) R = [] b = [] hh = -h if self.predict_x0 else h h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 h_phi_k = h_phi_1 / hh - 1 factorial_i = 1 if self.config.solver_type == "bh1": B_h = hh elif self.config.solver_type == "bh2": B_h = torch.expm1(hh) else: raise NotImplementedError() for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) b = torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) else: D1s = None # for order 1, we use a simplified version if order == 1: rhos_c = torch.tensor([0.5], dtype=x.dtype, 
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance the sample by one multistep UniPC iteration.

    The call first (optionally) corrects the incoming sample with UniC using the freshly converted model output,
    then records that output in the history buffers and runs the UniP predictor towards the next timestep.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

    Returns:
        [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
        True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `set_timesteps` has not been called yet (`num_inference_steps` is `None`).
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    # position of `timestep` in the schedule; an unknown timestep is treated as the final step
    last_index = len(self.timesteps) - 1
    matches = (self.timesteps == timestep).nonzero()
    step_index = matches.item() if len(matches) != 0 else last_index

    # the corrector needs a previous predictor run and may be switched off per step
    use_corrector = (
        step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None
    )

    model_output_convert = self.convert_model_output(model_output, timestep, sample)
    if use_corrector:
        sample = self.multistep_uni_c_bh_update(
            this_model_output=model_output_convert,
            this_timestep=timestep,
            last_sample=self.last_sample,
            this_sample=sample,
            order=self.this_order,
        )

    # now prepare to run the predictor
    if step_index == last_index:
        prev_timestep = 0
    else:
        prev_timestep = self.timesteps[step_index + 1]

    # shift the history one slot to the left and append the current entry
    for slot in range(self.config.solver_order - 1):
        self.model_outputs[slot] = self.model_outputs[slot + 1]
        self.timestep_list[slot] = self.timestep_list[slot + 1]
    self.model_outputs[-1] = model_output_convert
    self.timestep_list[-1] = timestep

    order_cap = self.config.solver_order
    if self.config.lower_order_final:
        # never use a higher order than there are steps left
        order_cap = min(order_cap, len(self.timesteps) - step_index)

    self.this_order = min(order_cap, self.lower_order_nums + 1)  # warmup for multistep
    assert self.this_order > 0

    self.last_sample = sample
    prev_sample = self.multistep_uni_p_bh_update(
        model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
        prev_timestep=prev_timestep,
        sample=sample,
        order=self.this_order,
    )

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix clean samples with noise at the given timesteps.

    Computes `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is looked
    up in `self.alphas_cumprod` for every entry of `timesteps`.

    Args:
        original_samples (`torch.FloatTensor`): the clean samples.
        noise (`torch.FloatTensor`): the noise to blend in.
        timesteps (`torch.IntTensor`): indices into `self.alphas_cumprod`.

    Returns:
        `torch.FloatTensor`: the noised samples.
    """
    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
    target_device = original_samples.device
    schedule = self.alphas_cumprod.to(device=target_device, dtype=original_samples.dtype)
    step_idx = timesteps.to(target_device)

    def as_broadcastable(per_sample):
        # flatten to 1-D, then append singleton axes until the rank matches the samples
        per_sample = per_sample.flatten()
        while per_sample.dim() < original_samples.dim():
            per_sample = per_sample.unsqueeze(-1)
        return per_sample

    signal_scale = as_broadcastable(schedule[step_idx] ** 0.5)
    noise_scale = as_broadcastable((1 - schedule[step_idx]) ** 0.5)
    return signal_scale * original_samples + noise_scale * noise
class SchedulerMixin:
    """
    Mixin with the loading/saving helpers shared by the schedulers.

    Class attributes:
        - **_compatibles** (`List[str]`) -- Names of classes that are compatible with the parent class, so that
          `from_config` can be used from a class different than the one used to save the config (should be
          overridden by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Dict[str, Any] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Build a scheduler from a JSON configuration stored in a directory or Hub repo.

        The configuration is read with `cls.load_config` and then handed to `cls.from_config`; keyword arguments
        that `load_config` does not consume are passed on to `from_config`.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing the scheduler configurations saved using
                      [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
            subfolder (`str`, *optional*):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `transformers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

        <Tip>

         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
        use this method in a firewalled environment.

        </Tip>
        """
        loaded = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        # the commit hash is requested from `load_config` but not used here
        config, leftover_kwargs, _commit_hash = loaded
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **leftover_kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Write the scheduler configuration to `save_directory` so it can be re-loaded with
        [`~SchedulerMixin.from_pretrained`]. All arguments are forwarded to `self.save_config`.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        # candidate names: this class plus everything it declares as compatible, without duplicates
        candidate_names = set([cls.__name__] + cls._compatibles)
        # look the names up on the top-level package this module belongs to
        root_package = importlib.import_module(__name__.split(".")[0])
        resolved = []
        for class_name in candidate_names:
            if hasattr(root_package, class_name):
                resolved.append(getattr(root_package, class_name))
        return resolved
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Dict[str, Any] = None,
    subfolder: Optional[str] = None,
    return_unused_kwargs=False,
    **kwargs,
):
    r"""
    Instantiate a scheduler, together with its state when it has one, from a pre-defined JSON file.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
            Can be either:

                - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                  organization name, like `google/ddpm-celebahq-256`.
                - A path to a *directory* containing a scheduler configuration saved using
                  [`~FlaxSchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
        subfolder (`str`, *optional*):
            In case the relevant files are located inside a subfolder of the model repo (either remote in
            huggingface.co or downloaded locally), you can specify the folder name here.
        return_unused_kwargs (`bool`, *optional*, defaults to `False`):
            Whether kwargs that are not consumed by the Python class should be returned or not.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the
            standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding
            the cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Will attempt to resume the download if such
            a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        output_loading_info(`bool`, *optional*, defaults to `False`):
            Whether or not to also return a dictionary containing missing keys, unexpected keys and error
            messages.
        local_files_only(`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model).
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
            generated when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we
            use a git-based system for storing models and other artifacts on huggingface.co, so `revision`
            can be any identifier allowed by git.

    All keyword arguments other than `subfolder` and `return_unused_kwargs` are forwarded to `cls.load_config`;
    whatever it hands back as unused is forwarded to `cls.from_config`.

    Returns:
        `(scheduler, state)`, or `(scheduler, state, unused_kwargs)` when `return_unused_kwargs` is `True`.
        `state` is the result of `scheduler.create_state()` when the scheduler defines `create_state` and has a
        truthy `has_state` attribute, and `None` otherwise.

    <Tip>

    It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
    models](https://huggingface.co/docs/hub/models-gated#gated-models).

    </Tip>

    <Tip>

    Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode)
    to use this method in a firewalled environment.

    </Tip>
    """
    config, kwargs = cls.load_config(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        subfolder=subfolder,
        return_unused_kwargs=True,
        **kwargs,
    )
    scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)

    # Bind `state` unconditionally: previously a scheduler without state left the name unassigned and the
    # return statements below raised UnboundLocalError.
    state = None
    if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
        state = scheduler.create_state()

    if return_unused_kwargs:
        return scheduler, state, unused_kwargs

    return scheduler, state
@flax.struct.dataclass
class CommonSchedulerState:
    """
    Beta schedule together with the arrays derived from it.

    `create` fills `betas` from the scheduler configuration, sets `alphas` to `1.0 - betas` and
    `alphas_cumprod` to the cumulative product of `alphas` along axis 0.
    """

    alphas: jnp.ndarray
    betas: jnp.ndarray
    alphas_cumprod: jnp.ndarray

    @classmethod
    def create(cls, scheduler):
        """
        Build the state from `scheduler.config`, using `scheduler.dtype` for the betas.

        `config.trained_betas` takes precedence when it is not `None`; otherwise `config.beta_schedule` selects
        one of `"linear"`, `"scaled_linear"` or `"squaredcos_cap_v2"`.

        Raises:
            NotImplementedError: if `config.beta_schedule` is none of the names above.
        """
        config = scheduler.config

        if config.trained_betas is not None:
            betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype)
        elif config.beta_schedule == "linear":
            betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype)
        elif config.beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model:
            # interpolate linearly between the square roots, then square.
            sqrt_betas = jnp.linspace(
                config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype
            )
            betas = sqrt_betas**2
        elif config.beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype)
        else:
            raise NotImplementedError(
                f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}"
            )

        alphas = 1.0 - betas
        return cls(alphas=alphas, betas=betas, alphas_cumprod=jnp.cumprod(alphas, axis=0))


def get_sqrt_alpha_prod(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """
    Return `sqrt(alphas_cumprod[timesteps])` and `sqrt(1 - alphas_cumprod[timesteps])`.

    Both results are flattened and then broadcast from the left to `original_samples.shape`.
    `noise` is accepted but not read by this function.
    """
    target_shape = original_samples.shape
    selected_alpha_bar = state.alphas_cumprod[timesteps]

    signal_scale = broadcast_to_shape_from_left((selected_alpha_bar**0.5).flatten(), target_shape)
    noise_scale = broadcast_to_shape_from_left(((1 - selected_alpha_bar) ** 0.5).flatten(), target_shape)

    return signal_scale, noise_scale


def add_noise_common(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """Return `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise` for the given timesteps."""
    signal_scale, noise_scale = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
    return signal_scale * original_samples + noise_scale * noise


def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
    """Return `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for the given timesteps."""
    signal_scale, noise_scale = get_sqrt_alpha_prod(state, sample, noise, timesteps)
    return signal_scale * noise - noise_scale * sample
================================================ # Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .scheduling_utils import SchedulerMixin @dataclass class VQDiffusionSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`): Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the denoising loop. 
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
    """
    Convert a batch of class-index vectors into a batch of log one-hot vectors.

    Args:
        x (`torch.LongTensor` of shape `(batch size, vector length)`):
            Batch of class indices
        num_classes (`int`):
            number of classes to be used for the onehot vectors

    Returns:
        `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
            Log onehot vectors. Zeros of the one-hot encoding are clamped to `1e-30` before the log is taken,
            so they come out as a large negative number rather than `-inf`.
    """
    # one_hot appends the class axis last; move it to position 1.
    one_hot_class_first = F.one_hot(x, num_classes).permute(0, 2, 1)
    return torch.log(one_hot_class_first.float().clamp(min=1e-30))


def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
    """
    Return `logits` with gumbel noise added.

    The noise is `-log(-log(u + 1e-30) + 1e-30)` with `u` drawn by `torch.rand` on `logits.device` using
    `generator`.
    """
    uniform_draw = torch.rand(logits.shape, device=logits.device, generator=generator)
    return -torch.log(-torch.log(uniform_draw + 1e-30) + 1e-30) + logits


def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
    """
    Cumulative and non-cumulative alpha schedules.

    See section 4.1.

    Returns:
        `(at, att)`: `att` holds the cumulative values, linearly spaced from `alpha_cum_start` to
        `alpha_cum_end`, followed by a trailing `1`. `at` holds the per-step ratios between consecutive
        cumulative values, the first one taken relative to `1`.
    """
    progress = np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1)
    cumulative = progress * (alpha_cum_end - alpha_cum_start) + alpha_cum_start

    with_leading_one = np.concatenate(([1], cumulative))
    at = with_leading_one[1:] / with_leading_one[:-1]
    att = np.concatenate((with_leading_one[1:], [1]))
    return at, att


def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
    """
    Cumulative and non-cumulative gamma schedules.

    See section 4.1.

    Returns:
        `(ct, ctt)`: `ctt` holds the cumulative values, linearly spaced from `gamma_cum_start` to
        `gamma_cum_end`, followed by a trailing `0`. `ct` is one minus the per-step ratio between consecutive
        values of `1 - ctt`, the first one taken relative to `1`.
    """
    progress = np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1)
    cumulative = progress * (gamma_cum_end - gamma_cum_start) + gamma_cum_start

    with_leading_zero = np.concatenate(([0], cumulative))
    complement = 1 - with_leading_zero
    ct = 1 - complement[1:] / complement[:-1]
    ctt = np.concatenate((with_leading_zero[1:], [0]))
    return ct, ctt
The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous diffusion timestep. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2111.14822 Args: num_vec_classes (`int`): The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. num_train_timesteps (`int`): Number of diffusion steps used to train the model. alpha_cum_start (`float`): The starting cumulative alpha value. alpha_cum_end (`float`): The ending cumulative alpha value. gamma_cum_start (`float`): The starting cumulative gamma value. gamma_cum_end (`float`): The ending cumulative gamma value. 
@register_to_config
def __init__(
    self,
    num_vec_classes: int,
    num_train_timesteps: int = 100,
    alpha_cum_start: float = 0.99999,
    alpha_cum_end: float = 0.000009,
    gamma_cum_start: float = 0.000009,
    gamma_cum_end: float = 0.99999,
):
    """
    Precompute the log-space transition parameters (alpha, beta, gamma) and their cumulative counterparts.

    Every schedule is converted to a float64 tensor, passed through `torch.log`, and stored as float32.
    """

    def _log_as_float32(schedule):
        # Take the log in float64, then store the result in single precision.
        return torch.log(torch.tensor(schedule.astype("float64"))).float()

    self.num_embed = num_vec_classes

    # By convention, the index for the mask class is the last class index
    self.mask_class = self.num_embed - 1

    at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
    ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)

    # Whatever is left after alpha and gamma is split evenly over the non-mask classes.
    num_non_mask_classes = self.num_embed - 1
    bt = (1 - at - ct) / num_non_mask_classes
    btt = (1 - att - ctt) / num_non_mask_classes

    self.log_at = _log_as_float32(at)
    self.log_bt = _log_as_float32(bt)
    self.log_ct = _log_as_float32(ct)

    self.log_cumprod_at = _log_as_float32(att)
    self.log_cumprod_bt = _log_as_float32(btt)
    self.log_cumprod_ct = _log_as_float32(ctt)

    # setable values
    self.num_inference_steps = None
    descending_steps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.timesteps = torch.from_numpy(descending_steps)
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: torch.long,
    sample: torch.LongTensor,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[VQDiffusionSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the
    docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed.

    Args:
        model_output (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
            The log probabilities for the predicted classes of the initial latent pixels (`log_p_x_0` in
            `self.q_posterior`). Does not include a prediction for the masked class as the initial unnoised
            image cannot be masked.

        timestep (`torch.long`):
            The timestep that determines which transition matrices are used. At timestep `0` the model output
            is used directly instead of going through `self.q_posterior`.

        sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
            The classes of each latent pixel at time `timestep` (`x_t` in `self.q_posterior`).

        generator (`torch.Generator` or None):
            RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from.

        return_dict (`bool`):
            option for returning tuple rather than VQDiffusionSchedulerOutput class

    Returns:
        [`VQDiffusionSchedulerOutput`] or `tuple`: [`VQDiffusionSchedulerOutput`] if `return_dict` is True,
        otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
    """
    if timestep == 0:
        log_p_x_t_min_1 = model_output
    else:
        log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)

    # Adding gumbel noise and taking the argmax over the class axis (dim 1) picks one class per latent pixel.
    log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)

    x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)

    if not return_dict:
        return (x_t_min_1,)

    return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) q = log_p_x_0 - log_q_x_t_given_x_0 # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... , # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True) # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n # . . . # . . . # . . . # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n q = q - q_log_sum_exp # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} # . . . # . . . # . . . # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1} # c_cumulative_{t-1} ... c_cumulative_{t-1} q = self.apply_cumulative_transitions(q, t - 1) # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n # . . . # . . . # . . . # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp # For each column, there are two possible cases. 
# # Where: # - sum(p_n(x_0))) is summing over all classes for x_0 # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's) # - C_j is the class transitioning to # # 1. x_t is masked i.e. x_t = c_k # # Simplifying the expression, the column vector is: # . # . # . # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0))) # . # . # . # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0)) # # From equation (11) stated in terms of forward probabilities, the last row is trivially verified. # # For the other rows, we can state the equation as ... # # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})] # # This verifies the other rows. # # 2. x_t is not masked # # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i: # . # . # . # C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) # . # . # . # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1})) # . # . # . # 0 # # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities. 
def log_Q_t_transitioning_to_known_class(
    self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
):
    """
    Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
    latent pixel in `x_t`.

    See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
    is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.

    Args:
        t (torch.Long):
            The timestep that determines which transition matrix is used.

        x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
            The classes of each latent pixel at time `t`.

        log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
            The log one-hot vectors of `x_t`

        cumulative (`bool`):
            If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
            we use the cumulative transition matrix `0`->`t`.

    Returns:
        `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)` when `cumulative` is `False`
        and `(batch size, num classes - 1, num latent pixels)` when `cumulative` is `True`:
            Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
            transition matrix.

            When cumulative, only `num classes - 1` rows are returned because the initial latent pixel cannot
            be masked. When non cumulative, the row for transitioning from the masked class is appended last.

            Where:
            - `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
            - C_0 is a class of a latent pixel embedding
            - C_k is the class of the masked latent pixel

            non-cumulative result (omitting logarithms):
            ```
            q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
                      .      .                     .
                      .               .            .
                      .                      .     .
            q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
            ```

            cumulative result (omitting logarithms):
            ```
            q_0_cumulative(x_t | x_0 = C_0)    ...  q_n_cumulative(x_t | x_0 = C_0)
                      .               .                          .
                      .                        .                 .
                      .                               .          .
            q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
            ```
    """
    if cumulative:
        a = self.log_cumprod_at[t]
        b = self.log_cumprod_bt[t]
        c = self.log_cumprod_ct[t]
    else:
        a = self.log_at[t]
        b = self.log_bt[t]
        c = self.log_ct[t]

        # The values in the onehot vector can also be used as the logprobs for transitioning
        # from masked latent pixels. If we are not calculating the cumulative transitions,
        # we need to save these vectors to be re-appended to the final matrix so the values
        # aren't overwritten.
        #
        # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector
        # if x_t is not masked
        #
        # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector
        # if x_t is masked
        log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)

    # `index_to_log_onehot` will add onehot vectors for masked pixels,
    # so the default one hot matrix has one too many rows. See the doc string
    # for an explanation of the dimensionality of the returned matrix.
    log_onehot_x_t = log_onehot_x_t[:, :-1, :]

    # this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
    #
    # Don't worry about what values this sets in the columns that mark transitions
    # to masked latent pixels. They are overwritten later with the `mask_class_mask`.
    #
    # Looking at the below logspace formula in non-logspace, each value will evaluate to either
    # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
    # or
    # `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
    #
    # See equation 7 for more details.
    log_Q_t = (log_onehot_x_t + a).logaddexp(b)

    # The whole column of each masked pixel is `c`
    mask_class_mask = (x_t == self.mask_class).unsqueeze(1).expand(-1, self.num_embed - 1, -1)
    log_Q_t[mask_class_mask] = c

    if not cumulative:
        log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)

    return log_Q_t


def apply_cumulative_transitions(self, q, t):
    """
    Apply the cumulative transition parameters of timestep `t` to the log probabilities `q`.

    In non-log space every entry of `q` becomes `q * a_cumulative_t + b_cumulative_t`, and one extra row
    holding `c_cumulative_t` (the masked class) is appended along dim 1.

    Args:
        q: log probabilities; `q.shape[0]` is used as the batch size and `q.shape[2]` as the number of latent
            pixels.
        t: index into the cumulative log schedules.

    Returns:
        A tensor with one more row along dim 1 than `q`.
    """
    bsz = q.shape[0]
    num_latent_pixels = q.shape[2]

    a = self.log_cumprod_at[t]
    b = self.log_cumprod_bt[t]
    # The masked-class row is the same value for every batch element and latent pixel.
    c = self.log_cumprod_ct[t].expand(bsz, 1, num_latent_pixels)

    q = (q + a).logaddexp(b)
    return torch.cat((q, c), dim=1)
min_decay (float): The minimum decay factor for the exponential moving average. update_after_step (int): The number of steps to wait before starting to update the EMA weights. use_ema_warmup (bool): Whether to use EMA warmup. inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA weights will be stored on CPU. @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). """ if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " "Please pass the parameters of the module instead." ) deprecate( "passing a `torch.nn.Module` to `ExponentialMovingAverage`", "1.0.0", deprecation_message, standard_warn=False, ) parameters = parameters.parameters() # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility use_ema_warmup = True if kwargs.get("max_value", None) is not None: deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) decay = kwargs["max_value"] if kwargs.get("min_value", None) is not None: deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." 
def get_decay(self, optimization_step: int) -> float:
    """
    Compute the decay factor for the exponential moving average.

    Returns `0.0` until `optimization_step` exceeds `self.update_after_step + 1`. After that the value comes
    from the warmup formula when `self.use_ema_warmup` is set and from `(1 + step) / (10 + step)` otherwise,
    and is then clamped to the range `[self.min_decay, self.decay]`.
    """
    step = max(0, optimization_step - self.update_after_step - 1)

    if step <= 0:
        return 0.0

    if self.use_ema_warmup:
        cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
    else:
        cur_decay_value = (1 + step) / (10 + step)

    cur_decay_value = min(cur_decay_value, self.decay)
    # make sure decay is not smaller than min_decay
    cur_decay_value = max(cur_decay_value, self.min_decay)
    return cur_decay_value


@torch.no_grad()
def step(self, parameters: Iterable[torch.nn.Parameter]):
    """
    Advance the optimization step counter and update the shadow parameters in place.

    Shadow parameters paired with a parameter that requires grad are moved towards it by
    `(1 - decay) * (shadow - param)`; the others are overwritten with a copy of the parameter.

    Args:
        parameters: Iterable of `torch.nn.Parameter`, paired positionally with `self.shadow_params`. Passing a
            `torch.nn.Module` is deprecated but still accepted.
    """
    if isinstance(parameters, torch.nn.Module):
        deprecation_message = (
            "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
            "Please pass the parameters of the module instead."
        )
        deprecate(
            "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
            "1.0.0",
            deprecation_message,
            standard_warn=False,
        )
        parameters = parameters.parameters()

    parameters = list(parameters)

    self.optimization_step += 1

    # Compute the decay factor for the exponential moving average.
    decay = self.get_decay(self.optimization_step)
    self.cur_decay_value = decay
    one_minus_decay = 1 - decay

    # Evaluate the DeepSpeed ZeRO-3 check once instead of once per parameter.
    zero3_enabled = is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled()
    if zero3_enabled:
        import deepspeed

    for s_param, param in zip(self.shadow_params, parameters):
        # Hold a context-manager *instance* in both branches. The previous code kept the `nullcontext` class
        # and entered it with `context_manager()`, which in the ZeRO-3 branch tried to call the
        # `GatheredParameters` instance instead of entering it.
        if zero3_enabled:
            context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
        else:
            context_manager = contextlib.nullcontext()

        with context_manager:
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)
""" parameters = list(parameters) for s_param, param in zip(self.shadow_params, parameters): param.data.copy_(s_param.to(param.device).data) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] def state_dict(self) -> dict: r""" Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during checkpointing to save the ema state dict. """ # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "min_decay": self.min_decay, "optimization_step": self.optimization_step, "update_after_step": self.update_after_step, "use_ema_warmup": self.use_ema_warmup, "inv_gamma": self.inv_gamma, "power": self.power, "shadow_params": self.shadow_params, } def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Save the current parameters for restoring later. parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. 
def load_state_dict(self, state_dict: dict) -> None:
    r"""
    Load the ExponentialMovingAverage state.

    Counterpart of `state_dict()`. Every key is optional; a missing key keeps the value currently stored on the
    instance. All values are validated before anything is written, so a rejected `state_dict` leaves the
    instance unchanged.

    Args:
        state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.

    Raises:
        ValueError: If `decay` is outside `[0, 1]`, if a scalar field has an unexpected type, or if
            `shadow_params` is present but is not a list of `torch.Tensor`.
    """
    # deepcopy, to be consistent with module API
    state_dict = copy.deepcopy(state_dict)

    decay = state_dict.get("decay", self.decay)
    if decay < 0.0 or decay > 1.0:
        raise ValueError("Decay must be between 0 and 1")

    min_decay = state_dict.get("min_decay", self.min_decay)
    if not isinstance(min_decay, float):
        raise ValueError("Invalid min_decay")

    optimization_step = state_dict.get("optimization_step", self.optimization_step)
    if not isinstance(optimization_step, int):
        raise ValueError("Invalid optimization_step")

    update_after_step = state_dict.get("update_after_step", self.update_after_step)
    if not isinstance(update_after_step, int):
        raise ValueError("Invalid update_after_step")

    use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
    if not isinstance(use_ema_warmup, bool):
        raise ValueError("Invalid use_ema_warmup")

    inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
    if not isinstance(inv_gamma, (float, int)):
        raise ValueError("Invalid inv_gamma")

    power = state_dict.get("power", self.power)
    if not isinstance(power, (float, int)):
        raise ValueError("Invalid power")

    shadow_params = state_dict.get("shadow_params", None)
    if shadow_params is not None:
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

    # Everything validated: commit in one go so a failure above cannot leave a half-updated object.
    self.decay = decay
    self.min_decay = min_decay
    self.optimization_step = optimization_step
    self.update_after_step = update_after_step
    self.use_ema_warmup = use_ema_warmup
    self.inv_gamma = inv_gamma
    self.power = power
    if shadow_params is not None:
        self.shadow_params = shadow_params
def check_min_version(min_version):
    """
    Raise `ImportError` when the installed diffusers `__version__` is older than `min_version`.

    Args:
        min_version: Version string to compare against. When it contains `"dev"`, the error message points to
            the install-from-source instructions instead of quoting the minimum version.

    Raises:
        ImportError: If `__version__` parses to a version lower than `min_version`.
    """
    # Guard clause: nothing to do when the installed version is recent enough.
    if version.parse(__version__) >= version.parse(min_version):
        return

    if "dev" in min_version:
        requirement = (
            "This example requires a source install from HuggingFace diffusers (see "
            "`https://huggingface.co/docs/diffusers/installation#install-from-source`),"
        )
    else:
        requirement = f"This example requires a minimum version of {min_version},"

    raise ImportError(requirement + f" but the version found is {__version__}.\n")
def apply_forward_hook(method):
    """
    Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is
    useful for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
    appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].

    This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

    `method` is returned undecorated when accelerate is not available or its version is older than 0.17.0.
    Otherwise the returned wrapper keeps the metadata of `method` (`__name__`, `__doc__`, ...).

    :param method: The method to decorate. This method should be a method of a PyTorch module.
    """
    # Local import keeps this block self-contained; only the wrapper construction below needs it.
    import functools

    if not is_accelerate_available():
        return method

    # Compare on the base version so pre-release/dev suffixes of accelerate do not affect the check.
    accelerate_version = version.parse(accelerate.__version__).base_version
    if version.parse(accelerate_version) < version.parse("0.17.0"):
        return method

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Trigger the hook's `pre_forward` (if one is attached) before running the wrapped method.
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
    """
    Emit `FutureWarning`s for deprecated arguments/attributes and return their values.

    Args:
        *args: Either a single `(attribute, version_name, message)` triple passed as three positional arguments,
            or one or more such triples passed as tuples.
        take_from: Where to look the deprecated names up. A `dict` has matching keys popped from it; any other
            object is queried with `hasattr`/`getattr`; `None` means "just warn".
        standard_warn: When true, the generated "... is deprecated and will be removed ..." sentence is prepended
            to `message`; when false only `message` is emitted.
        stacklevel: Forwarded to `warnings.warn`.

    Returns:
        `None` if no value was collected, the single value if exactly one was collected, otherwise a tuple of the
        collected values.

    Raises:
        ValueError: If the installed version is already >= `version_name` (the deprecation should be removed).
        TypeError: If `take_from` is a dict that still contains keys after all deprecated names were popped.
    """
    from .. import __version__

    deprecated_kwargs = take_from
    values = ()
    # Normalize the "three positional arguments" call form into a one-element tuple of triples.
    if not isinstance(args[0], tuple):
        args = (args,)

    for attribute, version_name, message in args:
        # Only the base version is compared, so dev/pre-release suffixes of `__version__` are ignored.
        if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
            raise ValueError(
                f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
                f" version {__version__} is >= {version_name}"
            )

        warning = None
        if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
            # NOTE: `pop` mutates the dict passed by the caller.
            values += (deprecated_kwargs.pop(attribute),)
            warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
        elif hasattr(deprecated_kwargs, attribute):
            values += (getattr(deprecated_kwargs, attribute),)
            warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
        elif deprecated_kwargs is None:
            warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."

        if warning is not None:
            # Parsed as `(warning + " ") if standard_warn else ""`: with `standard_warn=False` the generated
            # sentence is dropped and only `message` is shown.
            warning = warning + " " if standard_warn else ""
            warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)

    # Leftover keys in the dict are reported like an unexpected keyword argument of the calling function.
    if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
        # Index 1 is the frame of whoever called `deprecate`.
        call_frame = inspect.getouterframes(inspect.currentframe())[1]
        filename = call_frame.filename
        line_number = call_frame.lineno
        function = call_frame.function
        key, value = next(iter(deprecated_kwargs.items()))
        raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")

    if len(values) == 0:
        return
    elif len(values) == 1:
        return values[0]
    return values
def replace_example_docstring(example_docstring):
    """
    Build a decorator that substitutes the `Examples:` placeholder line of a docstring.

    The first docstring line that consists only of `Example:` / `Examples:` (surrounding whitespace allowed) is
    replaced by `example_docstring`; the decorated function itself is returned, with its `__doc__` updated.

    Args:
        example_docstring: Text that takes the place of the placeholder line.

    Returns:
        A decorator `fn -> fn`.

    Raises:
        ValueError: (from the decorator) if `fn` has no docstring or the docstring has no placeholder line.
    """

    def docstring_decorator(fn):
        func_doc = fn.__doc__
        # `__doc__` is None for undocumented callables (and when docstrings are stripped, e.g. `python -OO`);
        # treat that as "no placeholder" instead of failing with AttributeError on `.split`.
        lines = func_doc.split("\n") if func_doc is not None else []

        placeholder_index = next(
            (i for i, line in enumerate(lines) if re.search(r"^\s*Examples?:\s*$", line) is not None),
            None,
        )
        if placeholder_index is None:
            raise ValueError(
                f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
                f"current docstring is:\n{func_doc}"
            )

        lines[placeholder_index] = example_docstring
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
from ..utils import DummyObject, requires_backends class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) class FlaxStableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) class FlaxStableDiffusionInpaintPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) class FlaxStableDiffusionPipeline(metaclass=DummyObject): _backends = ["flax", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax", "transformers"]) ================================================ FILE: 4DoF/diffusers/utils/dummy_flax_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends class FlaxControlNetModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxModelMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxUNet2DConditionModel(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxAutoencoderKL(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDiffusionPipeline(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDDPMScheduler(metaclass=DummyObject): 
_backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxLMSDiscreteScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxPNDMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxSchedulerMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, 
["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) ================================================ FILE: 4DoF/diffusers/utils/dummy_note_seq_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class MidiProcessor(metaclass=DummyObject): _backends = ["note_seq"] def __init__(self, *args, **kwargs): requires_backends(self, ["note_seq"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["note_seq"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["note_seq"]) ================================================ FILE: 4DoF/diffusers/utils/dummy_onnx_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class OnnxRuntimeModel(metaclass=DummyObject): _backends = ["onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["onnx"]) ================================================ FILE: 4DoF/diffusers/utils/dummy_pt_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends class AutoencoderKL(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ControlNetModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ModelMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PriorTransformer(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class T5FilmDecoder(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class Transformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet1DModel(metaclass=DummyObject): _backends = 
["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet3DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class VQModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) def get_constant_schedule_with_warmup(*args, **kwargs): requires_backends(get_constant_schedule_with_warmup, ["torch"]) def get_cosine_schedule_with_warmup(*args, **kwargs): requires_backends(get_cosine_schedule_with_warmup, ["torch"]) def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, 
["torch"]) def get_linear_schedule_with_warmup(*args, **kwargs): requires_backends(get_linear_schedule_with_warmup, ["torch"]) def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) def get_scheduler(*args, **kwargs): requires_backends(get_scheduler, ["torch"]) class AudioPipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ConsistencyModelPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DanceDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiffusionPipeline(metaclass=DummyObject): _backends = 
["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiTPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ImagePipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KarrasVePipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LDMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LDMSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def 
from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class RePaintPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ScoreSdeVePipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) 
@classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DEISMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverSinglestepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, 
**kwargs): requires_backends(cls, ["torch"]) class EulerAncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EulerDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class HeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KarrasVeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class 
KDPM2DiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PNDMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class RePaintScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SchedulerMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UnCLIPScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UniPCMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, 
*args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class VQDiffusionScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EMAModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) ================================================ FILE: 4DoF/diffusers/utils/dummy_torch_and_librosa_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends


# Dummy AudioDiffusionPipeline declared with the "torch" and "librosa" backends.
# Each entry point below does nothing except forward the object and that same
# backend list to requires_backends.
# NOTE(review): requires_backends and DummyObject come from ..utils and are not
# shown here -- presumably the call raises when a listed backend is missing;
# confirm against that module.
class AudioDiffusionPipeline(metaclass=DummyObject):
    # Backend names attached to the class itself.
    _backends = ["torch", "librosa"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never read.
        requires_backends(self, ["torch", "librosa"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Arguments are accepted but never read.
        requires_backends(cls, ["torch", "librosa"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Arguments are accepted but never read.
        requires_backends(cls, ["torch", "librosa"])


# Dummy Mel with the same shape as the class above: a "torch" + "librosa"
# backend list and three entry points that only call requires_backends.
class Mel(metaclass=DummyObject):
    _backends = ["torch", "librosa"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "librosa"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "librosa"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "librosa"])


================================================
FILE: 4DoF/diffusers/utils/dummy_torch_and_scipy_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Dummy LMSDiscreteScheduler declared with the "torch" and "scipy" backends.
# The constructor and both classmethods ignore their arguments and only call
# requires_backends with that backend list.
# NOTE(review): the effect of requires_backends is defined in ..utils, outside
# this file -- verify there before relying on a specific exception type.
class LMSDiscreteScheduler(metaclass=DummyObject):
    _backends = ["torch", "scipy"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "scipy"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "scipy"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "scipy"])


================================================
FILE: 4DoF/diffusers/utils/dummy_torch_and_torchsde_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Dummy DPMSolverSDEScheduler declared with the "torch" and "torchsde" backends.
# The constructor and both classmethods accept any arguments, ignore them, and
# only call requires_backends with that backend list.
# NOTE(review): requires_backends and DummyObject live in ..utils and are not
# visible here -- presumably the call fails when a backend is not installed;
# confirm against that module.
class DPMSolverSDEScheduler(metaclass=DummyObject):
    # Backend names attached to the class itself.
    _backends = ["torch", "torchsde"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "torchsde"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "torchsde"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "torchsde"])


================================================
FILE: 4DoF/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Dummy StableDiffusionXLImg2ImgPipeline declared with the "torch",
# "transformers" and "invisible_watermark" backends. Every entry point only
# forwards to requires_backends with that list.
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "invisible_watermark"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])


# Dummy StableDiffusionXLPipeline; same backend list and same three
# requires_backends-only entry points as the img2img variant above.
class StableDiffusionXLPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "invisible_watermark"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])


================================================
FILE: 4DoF/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Dummy StableDiffusionKDiffusionPipeline declared with the "torch",
# "transformers" and "k_diffusion" backends. The constructor and both
# classmethods ignore their arguments and only call requires_backends.
# NOTE(review): requires_backends and DummyObject are imported from ..utils and
# not shown here -- presumably the call raises when a backend is unavailable;
# confirm against that module.
class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
    # Backend names attached to the class itself.
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])


================================================
FILE: 4DoF/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# The six classes below all follow one template: a `_backends` list of
# "torch", "transformers" and "onnx", plus `__init__`, `from_config` and
# `from_pretrained`, each of which only calls requires_backends with that list.
# NOTE(review): what requires_backends does on failure is defined in ..utils,
# outside this file -- verify there.
class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Dummy OnnxStableDiffusionInpaintPipeline (torch + transformers + onnx).
class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Dummy OnnxStableDiffusionInpaintPipelineLegacy (torch + transformers + onnx).
class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Dummy OnnxStableDiffusionPipeline (torch + transformers + onnx).
class OnnxStableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Dummy OnnxStableDiffusionUpscalePipeline (torch + transformers + onnx).
class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Dummy StableDiffusionOnnxPipeline (torch + transformers + onnx).
# NOTE(review): the name mirrors OnnxStableDiffusionPipeline above; whether it
# is a legacy alias cannot be determined from this file -- confirm upstream.
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


================================================
FILE: 4DoF/diffusers/utils/dummy_torch_and_transformers_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AltDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AudioLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class 
IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFInpaintingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ImageTextPipelineOutput(metaclass=DummyObject): 
_backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyPriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def 
__init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22ControlnetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22InpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): 
requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22PriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class PaintByExamplePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ShapEImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) 
@classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ShapEPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) 
@classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionDiffEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod 
def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionLDM3DPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionModelEditingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod 
def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPanoramaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionParadigmsPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPipelineSafe(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, 
**kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionSAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableUnCLIPPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class TextToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", 
"transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class TextToVideoZeroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UnCLIPPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UniDiffuserModel(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UniDiffuserPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) class UniDiffuserTextDecoder(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) class VideoToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VQDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) ================================================ FILE: 4DoF/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class SpectrogramDiffusionPipeline(metaclass=DummyObject): _backends = ["transformers", "torch", "note_seq"] def __init__(self, *args, **kwargs): requires_backends(self, ["transformers", "torch", "note_seq"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["transformers", "torch", "note_seq"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["transformers", "torch", "note_seq"]) ================================================ FILE: 4DoF/diffusers/utils/dynamic_modules_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities to dynamically load objects from the Hub.""" import importlib import inspect import json import os import re import shutil import sys from pathlib import Path from typing import Dict, Optional, Union from urllib import request from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info from packaging import version from .. import __version__ from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging COMMUNITY_PIPELINES_URL = ( "https://raw.githubusercontent.com/huggingface/diffusers/{revision}/examples/community/{pipeline}.py" ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_diffusers_versions(): url = "https://pypi.org/pypi/diffusers/json" releases = json.loads(request.urlopen(url).read())["releases"].keys() return sorted(releases, key=lambda x: version.Version(x)) def init_hf_modules(): """ Creates the cache directory for modules with an init, and adds it to the Python path. """ # This function has already been executed if HF_MODULES_CACHE already is in the Python path. if HF_MODULES_CACHE in sys.path: return sys.path.append(HF_MODULES_CACHE) os.makedirs(HF_MODULES_CACHE, exist_ok=True) init_path = Path(HF_MODULES_CACHE) / "__init__.py" if not init_path.exists(): init_path.touch() def create_dynamic_module(name: Union[str, os.PathLike]): """ Creates a dynamic module in the cache directory for modules. """ init_hf_modules() dynamic_module_path = Path(HF_MODULES_CACHE) / name # If the parent module does not exist yet, recursively create it. 
if not dynamic_module_path.parent.exists(): create_dynamic_module(dynamic_module_path.parent) os.makedirs(dynamic_module_path, exist_ok=True) init_path = dynamic_module_path / "__init__.py" if not init_path.exists(): init_path.touch() def get_relative_imports(module_file): """ Get the list of modules that are relatively imported in a module file. Args: module_file (`str` or `os.PathLike`): The module file to inspect. """ with open(module_file, "r", encoding="utf-8") as f: content = f.read() # Imports of the form `import .xxx` relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE) # Imports of the form `from .xxx import yyy` relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE) # Unique-ify return list(set(relative_imports)) def get_relative_import_files(module_file): """ Get the list of all files that are needed for a given module. Note that this function recurses through the relative imports (if a imports b and b imports c, it will return module files for b and c). Args: module_file (`str` or `os.PathLike`): The module file to inspect. """ no_change = False files_to_check = [module_file] all_relative_imports = [] # Let's recurse through all relative imports while not no_change: new_imports = [] for f in files_to_check: new_imports.extend(get_relative_imports(f)) module_path = Path(module_file).parent new_import_files = [str(module_path / m) for m in new_imports] new_import_files = [f for f in new_import_files if f not in all_relative_imports] files_to_check = [f"{f}.py" for f in new_import_files] no_change = len(new_import_files) == 0 all_relative_imports.extend(files_to_check) return all_relative_imports def check_imports(filename): """ Check if the current Python environment contains all the libraries that are imported in a file. 
""" with open(filename, "r", encoding="utf-8") as f: content = f.read() # Imports of the form `import xxx` imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) # Imports of the form `from xxx import yyy` imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE) # Only keep the top-level module imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")] # Unique-ify and test we got them all imports = list(set(imports)) missing_packages = [] for imp in imports: try: importlib.import_module(imp) except ImportError: missing_packages.append(imp) if len(missing_packages) > 0: raise ImportError( "This modeling file requires the following packages that were not found in your environment: " f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`" ) return get_relative_imports(filename) def get_class_in_module(class_name, module_path): """ Import a module on the cache directory for modules and extract a class from it. """ module_path = module_path.replace(os.path.sep, ".") module = importlib.import_module(module_path) if class_name is None: return find_pipeline_class(module) return getattr(module, class_name) def find_pipeline_class(loaded_module): """ Retrieve pipeline class that inherits from `DiffusionPipeline`. Note that there has to be exactly one class inheriting from `DiffusionPipeline`. """ from ..pipelines import DiffusionPipeline cls_members = dict(inspect.getmembers(loaded_module, inspect.isclass)) pipeline_class = None for cls_name, cls in cls_members.items(): if ( cls_name != DiffusionPipeline.__name__ and issubclass(cls, DiffusionPipeline) and cls.__module__.split(".")[0] != "diffusers" ): if pipeline_class is not None: raise ValueError( f"Multiple classes that inherit from {DiffusionPipeline.__name__} have been found:" f" {pipeline_class.__name__}, and {cls_name}. Please make sure to define only one in" f" {loaded_module}." 
) pipeline_class = cls return pipeline_class def get_cached_module_file( pretrained_model_name_or_path: Union[str, os.PathLike], module_file: str, cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, resume_download: bool = False, proxies: Optional[Dict[str, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, ): """ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached Transformers module. Args: pretrained_model_name_or_path (`str` or `os.PathLike`): This can be either: - a string, the *model id* of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - a path to a *directory* containing a configuration file saved using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. module_file (`str`): The name of the module file containing the class to look for. cache_dir (`str` or `os.PathLike`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force to (re-)download the configuration files and override the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. 
If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, will only try to load the tokenizer configuration from local files. You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). Returns: `str`: The path to the module inside the cache. """ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file. pretrained_model_name_or_path = str(pretrained_model_name_or_path) module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file) if os.path.isfile(module_file_or_url): resolved_module_file = module_file_or_url submodule = "local" elif pretrained_model_name_or_path.count("/") == 0: available_versions = get_diffusers_versions() # cut ".dev0" latest_version = "v" + ".".join(__version__.split(".")[:3]) # retrieve github version that matches if revision is None: revision = latest_version if latest_version[1:] in available_versions else "main" logger.info(f"Defaulting to latest_version: {revision}.") elif revision in available_versions: revision = f"v{revision}" elif revision == "main": revision = revision else: raise ValueError( f"`custom_revision`: {revision} does not exist. Please make sure to choose one of" f" {', '.join(available_versions + ['main'])}." 
) # community pipeline on GitHub github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path) try: resolved_module_file = cached_download( github_url, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=False, ) submodule = "git" module_file = pretrained_model_name_or_path + ".py" except EnvironmentError: logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") raise else: try: # Load from URL or cache if already cached resolved_module_file = hf_hub_download( pretrained_model_name_or_path, module_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, use_auth_token=use_auth_token, ) submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/"))) except EnvironmentError: logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.") raise # Check we have all the requirements in our environment modules_needed = check_imports(resolved_module_file) # Now we move the module inside our cached dynamic modules. full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule create_dynamic_module(full_submodule) submodule_path = Path(HF_MODULES_CACHE) / full_submodule if submodule == "local" or submodule == "git": # We always copy local files (we could hash the file to see if there was a change, and give them the name of # that hash, to only copy when there is a modification but it seems overkill for now). # The only reason we do the copy is to avoid putting too many folders in sys.path. 
shutil.copy(resolved_module_file, submodule_path / module_file) for module_needed in modules_needed: module_needed = f"{module_needed}.py" shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed) else: # Get the commit hash # TODO: we will get this info in the etag soon, so retrieve it from there and not here. if isinstance(use_auth_token, str): token = use_auth_token elif use_auth_token is True: token = HfFolder.get_token() else: token = None commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the # benefit of versioning. submodule_path = submodule_path / commit_hash full_submodule = full_submodule + os.path.sep + commit_hash create_dynamic_module(full_submodule) if not (submodule_path / module_file).exists(): shutil.copy(resolved_module_file, submodule_path / module_file) # Make sure we also have every file with relative for module_needed in modules_needed: if not (submodule_path / module_needed).exists(): get_cached_module_file( pretrained_model_name_or_path, f"{module_needed}.py", cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, use_auth_token=use_auth_token, revision=revision, local_files_only=local_files_only, ) return os.path.join(full_submodule, module_file) def get_class_from_dynamic_module( pretrained_model_name_or_path: Union[str, os.PathLike], module_file: str, class_name: Optional[str] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, force_download: bool = False, resume_download: bool = False, proxies: Optional[Dict[str, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, local_files_only: bool = False, **kwargs, ): """ Extracts a class from a module file, present in the local folder or repository of a model. 
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should therefore only be called on trusted repos. Args: pretrained_model_name_or_path (`str` or `os.PathLike`): This can be either: - a string, the *model id* of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - a path to a *directory* containing a configuration file saved using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. module_file (`str`): The name of the module file containing the class to look for. class_name (`str`): The name of the class to import in the module. cache_dir (`str` or `os.PathLike`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force to (re-)download the configuration files and override the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. use_auth_token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, will only try to load the tokenizer configuration from local files. You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). Returns: `type`: The class, dynamically imported from the module. Examples: ```python # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this # module. cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") ```""" # And lastly we get the class inside our newly created module final_module = get_cached_module_file( pretrained_model_name_or_path, module_file, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, use_auth_token=use_auth_token, revision=revision, local_files_only=local_files_only, ) return get_class_in_module(class_name, final_module.replace(".py", "")) ================================================ FILE: 4DoF/diffusers/utils/hub_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os
import re
import sys
import traceback
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4

from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    is_jinja_available,
)
from packaging import version
from requests import HTTPError

from .. import __version__
from .constants import (
    DEPRECATED_REVISION_ARGS,
    DIFFUSERS_CACHE,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
)
from .import_utils import (
    ENV_VARS_TRUE_VALUES,
    _flax_version,
    _jax_version,
    _onnxruntime_version,
    _torch_version,
    is_flax_available,
    is_onnx_available,
    is_torch_available,
)
from .logging import get_logger


logger = get_logger(__name__)


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.

    Args:
        user_agent (`Dict`, `str` or `None`):
            Extra information appended to the string: a dict is rendered as `key/value` pairs, a string is appended
            verbatim.

    Returns:
        `str`: Fields separated by `"; "`. When telemetry is disabled or the Hub is in offline mode, only the
        diffusers version, Python version and session id are reported, followed by `telemetry/off`.
    """
    ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
        return ua + "; telemetry/off"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_flax_available():
        ua += f"; jax/{_jax_version}"
        ua += f"; flax/{_flax_version}"
    if is_onnx_available():
        ua += f"; onnxruntime/{_onnxruntime_version}"
    # CI will set this value to True
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    return ua


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """
    Build the `<namespace>/<model_id>` repository name.

    The namespace is `organization` when given; otherwise it is the `name` returned by `whoami` for `token` (falling
    back to the token stored by `HfFolder` when `token` is `None`).
    """
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


def create_model_card(args, model_name):
    """
    Render the model card template with the training settings found on `args` and save it as `README.md` inside
    `args.output_dir`.

    Nothing is written when `args.local_rank` exists and is neither `-1` nor `0`.

    Args:
        args:
            Namespace-like object holding the training arguments. `learning_rate`, `train_batch_size`,
            `eval_batch_size`, `mixed_precision` and `output_dir` are required; every other attribute read here is
            optional and rendered as `None` when missing.
        model_name (`str`):
            Name of the model, also used to build the full repository name.

    Raises:
        ValueError: If Jinja is not available.
    """
    if not is_jinja_available():
        raise ValueError(
            "Modelcard rendering is based on Jinja templates."
            " Please make sure to have `jinja` installed before using `create_model_card`."
            " To install it, please run `pip install Jinja2`."
        )

    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = getattr(args, "hub_token", None)
    repo_name = get_full_repo_name(model_name, token=hub_token)

    # `dataset_name` is optional: read it once so the card metadata and the template agree, instead of accessing it
    # unguarded in one place (AttributeError when absent) and guarded in the other.
    dataset_name = getattr(args, "dataset_name", None)

    model_card = ModelCard.from_template(
        card_data=ModelCardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=dataset_name,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=getattr(args, "gradient_accumulation_steps", None),
        adam_beta1=getattr(args, "adam_beta1", None),
        adam_beta2=getattr(args, "adam_beta2", None),
        adam_weight_decay=getattr(args, "adam_weight_decay", None),
        adam_epsilon=getattr(args, "adam_epsilon", None),
        lr_scheduler=getattr(args, "lr_scheduler", None),
        lr_warmup_steps=getattr(args, "lr_warmup_steps", None),
        ema_inv_gamma=getattr(args, "ema_inv_gamma", None),
        ema_power=getattr(args, "ema_power", None),
        ema_max_decay=getattr(args, "ema_max_decay", None),
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)


def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
    """
    Extracts the commit hash from a resolved filename toward a cache file.

    Returns `commit_hash` unchanged when it is already given or when `resolved_file` is `None`, and `None` when the
    path has no `snapshots/<hash>/` component or that component does not match `REGEX_COMMIT_HASH`.
    """
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    # Normalize to forward slashes so the pattern below also matches Windows paths.
    resolved_file = str(Path(resolved_file).as_posix())
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    if search is None:
        return None
    commit_hash = search.groups()[0]
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None


# Old default cache path, potentially to be migrated.
# This logic was more or less taken from `transformers`, with the following differences:
# - Diffusers doesn't use custom environment variables to specify the cache path.
# - There is no need to migrate the cache format, just move the files to the new location.
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")


def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
    """
    Move every blob file found under `old_cache_dir` to the same relative location under `new_cache_dir` and leave a
    symlink to the new location behind.

    Args:
        old_cache_dir (`str`, *optional*):
            Directory to migrate from. Defaults to `old_diffusers_cache`.
        new_cache_dir (`str`, *optional*):
            Directory to migrate to. Defaults to `DIFFUSERS_CACHE`.
    """
    if new_cache_dir is None:
        new_cache_dir = DIFFUSERS_CACHE
    if old_cache_dir is None:
        old_cache_dir = old_diffusers_cache

    old_cache_dir = Path(old_cache_dir).expanduser()
    new_cache_dir = Path(new_cache_dir).expanduser()
    for old_blob_path in old_cache_dir.glob("**/blobs/*"):
        # Symlinks are skipped: they are what a previous migration leaves behind.
        if old_blob_path.is_file() and not old_blob_path.is_symlink():
            new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
            new_blob_path.parent.mkdir(parents=True, exist_ok=True)
            os.replace(old_blob_path, new_blob_path)
            try:
                os.symlink(new_blob_path, old_blob_path)
            except OSError:
                # Best effort only: the blob has already been moved, the old location just loses its shortcut.
                logger.warning(
                    "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
                )
    # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
# One-time, import-time cache migration. A marker file inside DIFFUSERS_CACHE stores the cache layout
# version; the migration below only runs while that version is lower than 1.
cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
if not os.path.isfile(cache_version_file):
    # No marker file: treat the cache as un-migrated.
    cache_version = 0
else:
    with open(cache_version_file) as f:
        try:
            cache_version = int(f.read())
        except ValueError:
            # Marker content is not an integer: fall back to "un-migrated".
            cache_version = 0

if cache_version < 1:
    # Only attempt a move when the old cache directory exists and has at least one entry.
    old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
    if old_cache_is_not_empty:
        logger.warning(
            "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
            "existing cached models. This is a one-time operation, you can interrupt it or run it "
            "later by calling `diffusers.utils.hub_utils.move_cache()`."
        )
        try:
            move_cache()
        except Exception as e:
            # A failed migration is logged with its traceback but does not propagate.
            trace = "\n".join(traceback.format_tb(e.__traceback__))
            logger.error(
                f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
                "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
                "message and we will do our best to help."
            )

if cache_version < 1:
    # Record version 1 whether or not the move above ran or succeeded.
    try:
        os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
            "the directory exists and can be written to."
        )
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """
    Insert `variant` before the last dot-separated component of `weights_name`.

    For example `"model.bin"` with variant `"fp16"` becomes `"model.fp16.bin"`. The name is returned unchanged
    when `variant` is `None`.
    """
    if variant is None:
        return weights_name

    *stem, extension = weights_name.split(".")
    return ".".join([*stem, variant, extension])


def _get_model_file(
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
    commit_hash=None,
):
    """
    Resolve the path of a weights file, either locally or by downloading it with `hf_hub_download`.

    Resolution order:
      1. `pretrained_model_name_or_path` is an existing file: it is returned as is.
      2. It is an existing directory: `<dir>/<weights_name>` is returned if present, otherwise
         `<dir>/<subfolder>/<weights_name>` (when `subfolder` is not `None`).
      3. Otherwise the file is requested through `hf_hub_download`, using `revision or commit_hash` as revision.
         If `revision` is one of `DEPRECATED_REVISION_ARGS`, a variant-named file is tried first and a
         `FutureWarning` is emitted whether or not that attempt succeeds.

    Raises:
        EnvironmentError: if the file is missing from a local directory, or if the download fails. The
            original exception is attached as `__cause__`.
    """
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path

    if os.path.isdir(pretrained_model_name_or_path):
        # Load from a PyTorch checkpoint: top-level file first, then the subfolder.
        local_candidates = [os.path.join(pretrained_model_name_or_path, weights_name)]
        if subfolder is not None:
            local_candidates.append(os.path.join(pretrained_model_name_or_path, subfolder, weights_name))
        for model_file in local_candidates:
            if os.path.isfile(model_file):
                return model_file
        raise EnvironmentError(
            f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
        )

    # Arguments shared by both download attempts below.
    download_kwargs = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "proxies": proxies,
        "resume_download": resume_download,
        "local_files_only": local_files_only,
        "use_auth_token": use_auth_token,
        "user_agent": user_agent,
        "subfolder": subfolder,
        "revision": revision or commit_hash,
    }

    # 1. First check if deprecated way of loading from branches is used
    if (
        revision in DEPRECATED_REVISION_ARGS
        and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
        and version.parse(version.parse(__version__).base_version) >= version.parse("0.20.0")
    ):
        try:
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=_add_variant(weights_name, revision),
                **download_kwargs,
            )
            warnings.warn(
                f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
                FutureWarning,
            )
            return model_file
        except Exception:
            # Best effort only: any failure falls through to the regular download below.
            # `Exception` (not a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
            warnings.warn(
                f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
                FutureWarning,
            )

    try:
        # 2. Load model file as usual
        return hf_hub_download(pretrained_model_name_or_path, filename=weights_name, **download_kwargs)
    # NOTE: the order of the handlers below is significant and must be kept.
    except RepositoryNotFoundError as err:
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
            "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
            "login`."
        ) from err
    except RevisionNotFoundError as err:
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
            "this model name. Check the model page at "
            f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
        ) from err
    except EntryNotFoundError as err:
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
        ) from err
    except HTTPError as err:
        raise EnvironmentError(
            f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
        ) from err
    except ValueError as err:
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
            f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
            f" directory containing a file named {weights_name} or"
            " \nCheckout your internet connection or see how to run the library in"
            " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
        ) from err
    except EnvironmentError as err:
        raise EnvironmentError(
            f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
            "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
            f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
            f"containing a file named {weights_name}"
        ) from err
""" import importlib.util import operator as op import os import sys from collections import OrderedDict from typing import Union from huggingface_hub.utils import is_jinja_available # noqa: F401 from packaging import version from packaging.version import Version, parse from . import logging # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata logger = logging.get_logger(__name__) # pylint: disable=invalid-name ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None if _torch_available: try: _torch_version = importlib_metadata.version("torch") logger.info(f"PyTorch version {_torch_version} available.") except importlib_metadata.PackageNotFoundError: _torch_available = False else: logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False _tf_version = "N/A" if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: _tf_available = importlib.util.find_spec("tensorflow") is not None if _tf_available: candidates = ( "tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64", ) _tf_version = None # For the metadata, we have to look for both tensorflow and tensorflow-cpu for 
# --- TensorFlow -------------------------------------------------------------------------------------
# Probed only when USE_TF allows it and USE_TORCH is not explicitly enabled.
_tf_version = "N/A"
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
        candidates = (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
            "intel-tensorflow",
            "intel-tensorflow-avx512",
            "tensorflow-rocm",
            "tensorflow-macos",
            "tensorflow-aarch64",
        )
        _tf_version = None
        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
        for pkg in candidates:
            try:
                _tf_version = importlib_metadata.version(pkg)
                break
            except importlib_metadata.PackageNotFoundError:
                pass
        _tf_available = _tf_version is not None
    if _tf_available:
        if version.parse(_tf_version) < version.parse("2"):
            logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
            _tf_available = False
        else:
            logger.info(f"TensorFlow version {_tf_version} available.")
else:
    logger.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False

# --- JAX / Flax -------------------------------------------------------------------------------------
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _flax_available = False
else:
    _flax_available = False

# --- safetensors ------------------------------------------------------------------------------------
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _safetensors_available = importlib.util.find_spec("safetensors") is not None
    if _safetensors_available:
        try:
            _safetensors_version = importlib_metadata.version("safetensors")
            logger.info(f"Safetensors version {_safetensors_version} available.")
        except importlib_metadata.PackageNotFoundError:
            _safetensors_available = False
else:
    # Only USE_SAFETENSORS is consulted above, so that is the variable to name here
    # (the message previously blamed USE_TF).
    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
    _safetensors_available = False

# --- Optional pure-Python dependencies --------------------------------------------------------------
# Pattern: the module must be importable AND its distribution metadata must be readable.
_transformers_available = importlib.util.find_spec("transformers") is not None
try:
    _transformers_version = importlib_metadata.version("transformers")
    logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
    _transformers_available = False

_inflect_available = importlib.util.find_spec("inflect") is not None
try:
    _inflect_version = importlib_metadata.version("inflect")
    logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
    _inflect_available = False

_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
    _unidecode_version = importlib_metadata.version("unidecode")
    logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
    _unidecode_available = False

# --- onnxruntime ------------------------------------------------------------------------------------
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
    candidates = (
        "onnxruntime",
        "onnxruntime-gpu",
        "ort_nightly_gpu",
        "onnxruntime-directml",
        "onnxruntime-openvino",
        "ort_nightly_directml",
        "onnxruntime-rocm",
        "onnxruntime-training",
    )
    _onnxruntime_version = None
    # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
    for pkg in candidates:
        try:
            _onnxruntime_version = importlib_metadata.version(pkg)
            break
        except importlib_metadata.PackageNotFoundError:
            pass
    _onnx_available = _onnxruntime_version is not None
    if _onnx_available:
        logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")

# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
# OpenCV is detected purely through distribution metadata, trying each known distribution name.
# The per-candidate lookup already handles PackageNotFoundError, so no outer handler is needed.
candidates = (
    "opencv-python",
    "opencv-contrib-python",
    "opencv-python-headless",
    "opencv-contrib-python-headless",
)
_opencv_version = None
for pkg in candidates:
    try:
        _opencv_version = importlib_metadata.version(pkg)
        break
    except importlib_metadata.PackageNotFoundError:
        pass
_opencv_available = _opencv_version is not None
if _opencv_available:
    logger.debug(f"Successfully imported cv2 version {_opencv_version}")

# Pattern for the probes below: the module must be importable AND its distribution metadata readable.
_scipy_available = importlib.util.find_spec("scipy") is not None
try:
    _scipy_version = importlib_metadata.version("scipy")
    logger.debug(f"Successfully imported scipy version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
    _scipy_available = False

_librosa_available = importlib.util.find_spec("librosa") is not None
try:
    _librosa_version = importlib_metadata.version("librosa")
    logger.debug(f"Successfully imported librosa version {_librosa_version}")
except importlib_metadata.PackageNotFoundError:
    _librosa_available = False

_accelerate_available = importlib.util.find_spec("accelerate") is not None
try:
    _accelerate_version = importlib_metadata.version("accelerate")
    logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
except importlib_metadata.PackageNotFoundError:
    _accelerate_available = False

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    _xformers_version = importlib_metadata.version("xformers")
    if _torch_available:
        import torch

        # Deliberately not caught below: an unsupported torch/xformers pairing aborts the import.
        if version.Version(torch.__version__) < version.Version("1.12"):
            raise ValueError("PyTorch should be >= 1.12")
    logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
    _xformers_available = False

_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
try:
    _k_diffusion_version = importlib_metadata.version("k_diffusion")
    logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
except importlib_metadata.PackageNotFoundError:
    _k_diffusion_available = False

_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
    _note_seq_version = importlib_metadata.version("note_seq")
    logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
except importlib_metadata.PackageNotFoundError:
    _note_seq_available = False

_wandb_available = importlib.util.find_spec("wandb") is not None
try:
    _wandb_version = importlib_metadata.version("wandb")
    logger.debug(f"Successfully imported wandb version {_wandb_version}")
except importlib_metadata.PackageNotFoundError:
    _wandb_available = False

_omegaconf_available = importlib.util.find_spec("omegaconf") is not None
try:
    _omegaconf_version = importlib_metadata.version("omegaconf")
    logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
except importlib_metadata.PackageNotFoundError:
    _omegaconf_available = False

# `find_spec` returns a ModuleSpec or None; compare against None so the flag is a real bool
# like every other `_*_available` flag.
_tensorboard_available = importlib.util.find_spec("tensorboard") is not None
try:
    _tensorboard_version = importlib_metadata.version("tensorboard")
    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
    _tensorboard_available = False

_compel_available = importlib.util.find_spec("compel") is not None
try:
    _compel_version = importlib_metadata.version("compel")
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    _compel_available = False

_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = importlib_metadata.version("ftfy")
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    _ftfy_available = False

_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata under different name
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported bs4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False

_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
    _torchsde_version = importlib_metadata.version("torchsde")
    logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
    _torchsde_available = False

_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
    _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
except importlib_metadata.PackageNotFoundError:
    _invisible_watermark_available = False


# Public accessors: each returns the module-level availability flag computed at import time.
def is_torch_available():
    return _torch_available


def is_safetensors_available():
    return _safetensors_available


def is_tf_available():
    return _tf_available


def is_flax_available():
    return _flax_available


def is_transformers_available():
    return _transformers_available


def is_inflect_available():
    return _inflect_available


def is_unidecode_available():
    return _unidecode_available


def is_onnx_available():
    return _onnx_available


def is_opencv_available():
    return _opencv_available


def is_scipy_available():
    return _scipy_available


def is_librosa_available():
    return _librosa_available


def is_xformers_available():
    return _xformers_available


def is_accelerate_available():
    return _accelerate_available


def is_k_diffusion_available():
    return _k_diffusion_available


def is_note_seq_available():
    return _note_seq_available


def is_wandb_available():
    return _wandb_available


def is_omegaconf_available():
    return _omegaconf_available


def is_tensorboard_available():
    return _tensorboard_available


def is_compel_available():
    return _compel_available


def is_ftfy_available():
    return _ftfy_available


def is_bs4_available():
    return _bs4_available


def is_torchsde_available():
    return _torchsde_available


def is_invisible_watermark_available():
    return _invisible_watermark_available
is_invisible_watermark_available(): return _invisible_watermark_available # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the installation page: https://github.com/google/flax and follow the ones that match your environment. """ # docstyle-ignore INFLECT_IMPORT_ERROR = """ {0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install inflect` """ # docstyle-ignore PYTORCH_IMPORT_ERROR = """ {0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ # docstyle-ignore ONNX_IMPORT_ERROR = """ {0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip install onnxruntime` """ # docstyle-ignore OPENCV_IMPORT_ERROR = """ {0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip install opencv-python` """ # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install scipy` """ # docstyle-ignore LIBROSA_IMPORT_ERROR = """ {0} requires the librosa library but it was not found in your environment. Checkout the instructions on the installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. """ # docstyle-ignore TRANSFORMERS_IMPORT_ERROR = """ {0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip install transformers` """ # docstyle-ignore UNIDECODE_IMPORT_ERROR = """ {0} requires the unidecode library but it was not found in your environment. 
You can install it with pip: `pip install Unidecode` """ # docstyle-ignore K_DIFFUSION_IMPORT_ERROR = """ {0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip install k-diffusion` """ # docstyle-ignore NOTE_SEQ_IMPORT_ERROR = """ {0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip install note-seq` """ # docstyle-ignore WANDB_IMPORT_ERROR = """ {0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip install wandb` """ # docstyle-ignore OMEGACONF_IMPORT_ERROR = """ {0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip install omegaconf` """ # docstyle-ignore TENSORBOARD_IMPORT_ERROR = """ {0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip install tensorboard` """ # docstyle-ignore COMPEL_IMPORT_ERROR = """ {0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` """ # docstyle-ignore BS4_IMPORT_ERROR = """ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: `pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. """ # docstyle-ignore FTFY_IMPORT_ERROR = """ {0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. """ # docstyle-ignore TORCHSDE_IMPORT_ERROR = """ {0} requires the torchsde library but it was not found in your environment. 
def requires_backends(obj, backends):
    """
    Raise an `ImportError` if any backend needed by `obj` is unavailable.

    Args:
        obj: object whose name is used in the error message (`obj.__name__` if present, otherwise the name of
            its class).
        backends (`str` or `list`/`tuple` of `str`): keys of `BACKENDS_MAPPING` to check.

    Raises:
        ImportError: with the concatenated install hints of every unavailable backend, or when `obj` is one of
            the pipelines listed below and the installed `transformers` is older than the stated minimum.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))

    if name in [
        "VersatileDiffusionTextToImagePipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionDualGuidedPipeline",
        "StableDiffusionImageVariationPipeline",
        "UnCLIPPipeline",
    ] and is_transformers_version("<", "4.25.0"):
        raise ImportError(
            f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )

    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
        "<", "4.26.0"
    ):
        raise ImportError(
            f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
    `requires_backend` each time a user tries to access any method of that class.
    """

    def __getattr__(cls, key):
        # Only reached when normal lookup on the class failed. Underscore names (dunders, `_backends`, ...)
        # must fail as a normal missing attribute so `hasattr` and copy/pickle probing behave as usual.
        # `type` has no `__getattr__` to delegate to, so raise the AttributeError explicitly.
        if key.startswith("_"):
            raise AttributeError(f"type object '{cls.__name__}' has no attribute '{key}'")
        requires_backends(cls, cls._backends)
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
    """
    Compares a library version to some requirement using a given operation.

    Args:
        library_or_version (`str` or `packaging.version.Version`):
            A library name or a version to check. A `str` is treated as a distribution name whose installed
            version is looked up through `importlib_metadata.version`.
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`.
        requirement_version (`str`):
            The version to compare the library version against

    Raises:
        ValueError: if `operation` is not a key of `STR_OPERATION_TO_FUNC`.
    """
    if operation not in STR_OPERATION_TO_FUNC:
        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
    operation = STR_OPERATION_TO_FUNC[operation]
    if isinstance(library_or_version, str):
        library_or_version = parse(importlib_metadata.version(library_or_version))
    return operation(library_or_version, parse(requirement_version))


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
def is_torch_version(operation: str, version: str):
    """
    Compares the current PyTorch version to a given reference with an operation.

    Returns `False` when PyTorch is not available, like the sibling `is_*_version` helpers. In that case
    `_torch_version` is the placeholder `"N/A"`, which is not a parseable version.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of PyTorch
    """
    if not _torch_available:
        return False
    return compare_versions(parse(_torch_version), operation, version)


def is_transformers_version(operation: str, version: str):
    """
    Compares the current Transformers version to a given reference with an operation.

    Returns `False` when Transformers is not available.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _transformers_available:
        return False
    return compare_versions(parse(_transformers_version), operation, version)


def is_accelerate_version(operation: str, version: str):
    """
    Compares the current Accelerate version to a given reference with an operation.

    Returns `False` when Accelerate is not available.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _accelerate_available:
        return False
    return compare_versions(parse(_accelerate_version), operation, version)


def is_k_diffusion_version(operation: str, version: str):
    """
    Compares the current k-diffusion version to a given reference with an operation.

    Returns `False` when k-diffusion is not available.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _k_diffusion_available:
        return False
    return compare_versions(parse(_k_diffusion_version), operation, version)


class OptionalDependencyNotAvailable(BaseException):
    """An error indicating that an optional dependency of Diffusers was not found in the environment."""
# Guards creation/removal of the library's default handler.
_lock = threading.Lock()
# The StreamHandler installed on the library root logger; None until first configured.
_default_handler: Optional[logging.Handler] = None

log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

_default_log_level = logging.WARNING

_tqdm_active = True


def _get_default_logging_level():
    """
    If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
    not - fall back to `_default_log_level`
    """
    env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
    if env_level_str:
        if env_level_str in log_levels:
            return log_levels[env_level_str]
        # Unknown value: warn on the Python root logger and use the default.
        logging.getLogger().warning(
            f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
            f"has to be one of: {', '.join(log_levels.keys())}"
        )
    return _default_log_level


def _get_library_name() -> str:
    """Return the top-level package name this module belongs to."""
    return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:
    """Return the logger named after the library's top-level package."""
    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
    """Install the default stderr handler on the library root logger (only the first call has an effect)."""
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_get_default_logging_level())
        library_root_logger.propagate = False


def _reset_library_root_logger() -> None:
    """Remove the default handler and reset the library root logger's level to NOTSET."""
    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None


def get_log_levels_dict():
    """Return the mapping from verbosity names to `logging` level integers."""
    return log_levels


def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return a logger with the specified name.

    This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
    """
    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()
    return logging.getLogger(name)


def get_verbosity() -> int:
    """
    Return the current level for the 🤗 Diffusers' root logger as an `int`.

    Returns:
        `int`:
            Logging level integers which can be one of:

            - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
            - `40`: `diffusers.logging.ERROR`
            - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
            - `20`: `diffusers.logging.INFO`
            - `10`: `diffusers.logging.DEBUG`
    """
    _configure_library_root_logger()
    return _get_library_root_logger().getEffectiveLevel()


def set_verbosity(verbosity: int) -> None:
    """
    Set the verbosity level for the 🤗 Diffusers' root logger.

    Args:
        verbosity (`int`):
            Logging level which can be one of:

            - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
            - `diffusers.logging.ERROR`
            - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
            - `diffusers.logging.INFO`
            - `diffusers.logging.DEBUG`
    """
    _configure_library_root_logger()
    _get_library_root_logger().setLevel(verbosity)


def set_verbosity_info():
    """Set the verbosity to the `INFO` level."""
    return set_verbosity(INFO)


def set_verbosity_warning():
    """Set the verbosity to the `WARNING` level."""
    return set_verbosity(WARNING)


def set_verbosity_debug():
    """Set the verbosity to the `DEBUG` level."""
    return set_verbosity(DEBUG)


def set_verbosity_error():
    """Set the verbosity to the `ERROR` level."""
    return set_verbosity(ERROR)


def disable_default_handler() -> None:
    """Disable the default handler of the 🤗 Diffusers' root logger."""
    _configure_library_root_logger()

    assert _default_handler is not None
    _get_library_root_logger().removeHandler(_default_handler)


def enable_default_handler() -> None:
    """Enable the default handler of the 🤗 Diffusers' root logger."""
    _configure_library_root_logger()

    assert _default_handler is not None
    _get_library_root_logger().addHandler(_default_handler)


def add_handler(handler: logging.Handler) -> None:
    """adds a handler to the HuggingFace Diffusers' root logger."""
    _configure_library_root_logger()

    assert handler is not None
    _get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
    """
    removes given handler from the HuggingFace Diffusers' root logger.

    Raises:
        AssertionError: if `handler` is `None` or is not currently attached to the library root logger.
    """
    _configure_library_root_logger()

    # The handler must be attached. The check was previously inverted (`not in`), which made every
    # removal of a registered handler fail.
    assert handler is not None and handler in _get_library_root_logger().handlers
    _get_library_root_logger().removeHandler(handler)
handler from the HuggingFace Diffusers' root logger.""" _configure_library_root_logger() assert handler is not None and handler not in _get_library_root_logger().handlers _get_library_root_logger().removeHandler(handler) def disable_propagation() -> None: """ Disable propagation of the library log outputs. Note that log propagation is disabled by default. """ _configure_library_root_logger() _get_library_root_logger().propagate = False def enable_propagation() -> None: """ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent double logging if the root logger has been configured. """ _configure_library_root_logger() _get_library_root_logger().propagate = True def enable_explicit_format() -> None: """ Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows: ``` [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE ``` All handlers currently bound to the root logger are affected by this method. """ handlers = _get_library_root_logger().handlers for handler in handlers: formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s") handler.setFormatter(formatter) def reset_format() -> None: """ Resets the formatting for 🤗 Diffusers' loggers. All handlers currently bound to the root logger are affected by this method. 
class EmptyTqdm:
    """
    Stand-in for a tqdm progress bar that does nothing.

    It iterates over the first positional argument it was built with, works as a context manager, and answers any
    other attribute lookup with a function that accepts anything and returns `None`.
    """

    def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
        # Only the first positional argument (the wrapped iterable) is kept; everything else is ignored.
        if args:
            self._iterator = args[0]
        else:
            self._iterator = None

    def __iter__(self):
        return iter(self._iterator)

    def __getattr__(self, _):
        """Return a no-op callable for any attribute that is not otherwise defined."""

        def _noop(*args, **kwargs):  # pylint: disable=unused-argument
            return None

        return _noop

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        # Returning None lets an exception raised inside the `with` body propagate.
        return None
dataset_name }}` dataset. ## Intended uses & limitations #### How to use ```python # TODO: add an example code snippet for running this diffusion pipeline ``` #### Limitations and bias [TODO: provide examples of latent issues and potential remediations] ## Training data [TODO: describe the data used to train the model] ### Training hyperparameters The following hyperparameters were used during training: - learning_rate: {{ learning_rate }} - train_batch_size: {{ train_batch_size }} - eval_batch_size: {{ eval_batch_size }} - gradient_accumulation_steps: {{ gradient_accumulation_steps }} - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} - lr_scheduler: {{ lr_scheduler }} - lr_warmup_steps: {{ lr_warmup_steps }} - ema_inv_gamma: {{ ema_inv_gamma }} - ema_power: {{ ema_power }} - ema_max_decay: {{ ema_max_decay }} - mixed_precision: {{ mixed_precision }} ### Training results 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) ================================================ FILE: 4DoF/diffusers/utils/outputs.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
class BaseOutput(OrderedDict):
    """
    Base class for model outputs that are declared as dataclasses.

    Entries can be read by string key (like a dictionary) or by integer / slice (like a tuple); fields whose value
    is `None` are left out of the mapping and therefore out of the tuple view. Apart from that, and apart from the
    mutating methods that are explicitly disabled below, it behaves like a regular Python dictionary.

    <Tip warning={true}>

    You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a
    tuple first.

    </Tip>
    """

    def __post_init__(self):
        declared = fields(self)

        # Safety and consistency checks
        if len(declared) == 0:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        head_value = getattr(self, declared[0].name)

        rest_is_none = True
        for extra in declared[1:]:
            if getattr(self, extra.name) is not None:
                rest_is_none = False
                break

        if rest_is_none and isinstance(head_value, dict):
            # A single dict passed as the first field is expanded into the mapping entries.
            for key, value in head_value.items():
                self[key] = value
            return

        # Otherwise every field that is not None becomes an entry, in declaration order.
        for spec in declared:
            current = getattr(self, spec.name)
            if current is not None:
                self[spec.name] = current

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        if not isinstance(k, str):
            # Integers and slices index into the tuple view.
            return self.to_tuple()[k]
        snapshot = dict(self.items())
        return snapshot[k]

    def __setattr__(self, name, value):
        is_existing_key = name in self.keys()
        if is_existing_key and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        collected = [self[key] for key in self.keys()]
        return tuple(collected)
images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images ================================================ FILE: 4DoF/diffusers/utils/testing_utils.py ================================================ import inspect import logging import multiprocessing import os import random import re import tempfile import unittest import urllib.parse from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path from typing import List, Optional, Union import numpy as np import PIL.Image import PIL.ImageOps import requests from packaging import version from .import_utils import ( BACKENDS_MAPPING, is_compel_available, is_flax_available, is_note_seq_available, is_onnx_available, is_opencv_available, is_torch_available, is_torch_version, is_torchsde_available, ) from .logging import get_logger global_rng = random.Random() logger = get_logger(__name__) if is_torch_available(): import torch if "DIFFUSERS_TEST_DEVICE" in os.environ: torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] available_backends = ["cuda", "cpu", "mps"] if torch_device not in available_backends: raise ValueError( f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:" f" {available_backends}" ) logger.info(f"torch_device overrode to {torch_device}") else: torch_device = "cuda" if torch.cuda.is_available() else "cpu" is_torch_higher_equal_than_1_12 = version.parse( version.parse(torch.__version__).base_version ) >= version.parse("1.12") if is_torch_higher_equal_than_1_12: # Some builds of torch 1.12 don't have the mps backend registered. 
def get_tests_dir(append_path=None):
    """
    Locate the `tests` directory that contains the calling file.

    Args:
        append_path: optional path to append to the tests dir path

    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
        joined after the `tests` dir if the former is provided.

    Raises:
        ValueError: if no directory above the calling file has a path ending in `tests`.
    """
    # this function caller's __file__
    caller__file__ = inspect.stack()[1][1]
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))

    while not tests_dir.endswith("tests"):
        parent_dir = os.path.dirname(tests_dir)
        if parent_dir == tests_dir:
            # `dirname` of the filesystem root is the root itself, so without this check the loop never ends
            # when the caller does not live under a `tests` directory.
            raise ValueError(f"Could not find a `tests` directory above {caller__file__}.")
        tests_dir = parent_dir

    if append_path:
        return os.path.join(tests_dir, append_path)
    return tests_dir
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """
    Create a contiguous float32 tensor of the given shape filled with `rng.random() * scale` samples.

    Args:
        shape: iterable of dimension sizes.
        scale: factor applied to every sample.
        rng: object exposing `random()`; the module-level `global_rng` is used when `None`.
        name: accepted but not used.
    """
    source = global_rng if rng is None else rng

    element_count = 1
    for extent in shape:
        element_count = element_count * extent

    samples = [source.random() * scale for _ in range(element_count)]

    flat = torch.tensor(data=samples, dtype=torch.float)
    return flat.view(shape).contiguous()
def skip_mps(test_case):
    """Decorator that skips the decorated test when the module-level `torch_device` is 'mps'."""
    on_mps = torch_device == "mps"
    return unittest.skipIf(on_mps, "test requires non 'mps' device")(test_case)
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
    """
    Resolve `arry` to a numpy array.

    Args:
        arry: an `np.ndarray` (returned as is), an `http://` / `https://` URL of a `.npy` payload, or a path to a
            local file readable by `np.load`.
        local_path: when given together with a string `arry`, nothing is loaded; see the note below.

    Returns:
        The array, except in the `local_path` case.

    Raises:
        ValueError: if `arry` is a string that is neither a URL nor an existing file, or is neither a string nor an
            ndarray.

    NOTE(review): with `local_path` set, the function returns a *path string* built from `local_path` and the 5th-,
    2nd- and last-from-the-end components of `arry` split on "/", which contradicts the declared return type.
    Kept as is; confirm against callers before changing.
    """
    if isinstance(arry, np.ndarray):
        return arry

    if not isinstance(arry, str):
        raise ValueError(
            "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
            " ndarray."
        )

    # local_path = "/home/patrick_huggingface_co/"
    if local_path is not None:
        # local_path can be passed to correct images of tests
        parts = arry.split("/")
        relative = "/".join([parts[-5], parts[-2], parts[-1]])
        return os.path.join(local_path, relative)

    if arry.startswith("http://") or arry.startswith("https://"):
        response = requests.get(arry)
        response.raise_for_status()
        return np.load(BytesIO(response.content))

    if os.path.isfile(arry):
        return np.load(arry)

    raise ValueError(
        f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
    )
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
    """
    Write `video_frames` to an mp4 file (codec `mp4v`, 8 fps) with OpenCV and return the file path.

    Args:
        video_frames: frames as arrays with three dimensions (height, width, channels); each frame is converted
            with `cv2.COLOR_RGB2BGR` before being written. The frame size is taken from the first frame.
        output_video_path: destination path. When `None`, the name of a `tempfile.NamedTemporaryFile` with suffix
            `.mp4` is used.

    Returns:
        The path the video was written to.

    Raises:
        ImportError: if OpenCV is not available.
    """
    if is_opencv_available():
        import cv2
    else:
        raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, c = video_frames[0].shape
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
    try:
        for frame in video_frames:
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer explicitly so the file is finalized before the path is handed back to the caller,
        # instead of whenever the writer object happens to be garbage collected.
        video_writer.release()
    return output_video_path


def load_hf_numpy(path) -> np.ndarray:
    """
    Load a numpy array from the `fusing/diffusers-testing` dataset repository or from a full URL.

    Args:
        path: either a complete `http://` / `https://` URL, which is used unchanged, or a path relative to the
            dataset repository root, which is URL-quoted and appended to the repository base URL.

    Returns:
        Whatever `load_numpy` returns for the resolved URL.
    """
    # Only prefix paths that are not already URLs. The previous condition, `not a or b`, parsed as
    # `(not a) or b`, so every `https://` URL was wrongly prefixed with the dataset base URL as well.
    if not path.startswith(("http://", "https://")):
        path = os.path.join(
            "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
        )

    return load_numpy(path)
def pytest_terminal_summary_main(tr, id):
    """
    Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the `reports`
    directory (created relative to the current directory if missing). The report files are prefixed with the test
    suite name: `reports/{id}_{report}.txt`.

    This function emulates --duration and -rA pytest arguments.

    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
    there.

    Args:
    - tr: `terminalreporter` passed from `conftest.py`
    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. An empty `id`
      falls back to `"tests"`.

    NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
    changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
    plugins and interfere.

    NOTE(review): the reporter state saved below (`tr._tw`, `tr.reportchars`, `config.option.tbstyle`) is restored
    only at the very end of the function, not in a `finally`, so it stays modified if one of the summary calls
    raises. The names `id` and `dir` shadow builtins inside this function.
    """
    from _pytest.config import create_terminal_writer

    if not len(id):
        id = "tests"

    # Remember the reporter state that gets overwritten below so it can be put back at the end.
    config = tr.config
    orig_writer = config.get_terminal_writer()
    orig_tbstyle = config.option.tbstyle
    orig_reportchars = tr.reportchars

    dir = "reports"
    Path(dir).mkdir(parents=True, exist_ok=True)
    # Maps each report name to its output file.
    report_files = {
        k: f"{dir}/{id}_{k}.txt"
        for k in [
            "durations",
            "errors",
            "failures_long",
            "failures_short",
            "failures_line",
            "passes",
            "stats",
            "summary_short",
            "warnings",
        ]
    }

    # custom durations report
    # note: there is no need to call pytest --durations=XX to get this separate report
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
    dlist = []
    for replist in tr.stats.values():
        for rep in replist:
            if hasattr(rep, "duration"):
                dlist.append(rep)
    if dlist:
        # Slowest first; the listing stops at the first entry below `durations_min` and notes how many were cut.
        dlist.sort(key=lambda x: x.duration, reverse=True)
        with open(report_files["durations"], "w") as f:
            durations_min = 0.05  # sec
            f.write("slowest durations\n")
            for i, rep in enumerate(dlist):
                if rep.duration < durations_min:
                    f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
                    break
                f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

    def summary_failures_short(tr):
        # Writes the failed reports through `tr`, keeping only the tail of each long traceback.
        # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
        reports = tr.getreports("failed")
        if not reports:
            return
        tr.write_sep("=", "FAILURES SHORT STACK")
        for rep in reports:
            msg = tr._getfailureheadline(rep)
            tr.write_sep("_", msg, red=True, bold=True)
            # chop off the optional leading extra frames, leaving only the last one
            longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
            tr._tw.line(longrepr)
            # note: not printing out any rep.sections to keep the report short

    # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
    # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
    # pytest-instafail does that)

    # report failures with line/short/long styles
    config.option.tbstyle = "auto"  # full tb
    with open(report_files["failures_long"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_failures()

    # The short report is produced by the local helper above while tbstyle is still "auto".
    # config.option.tbstyle = "short" # short tb
    with open(report_files["failures_short"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        summary_failures_short(tr)

    config.option.tbstyle = "line"  # one line per error
    with open(report_files["failures_line"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_failures()

    with open(report_files["errors"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_errors()

    with open(report_files["warnings"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_warnings()  # normal warnings
        tr.summary_warnings()  # final warnings

    tr.reportchars = "wPpsxXEf"  # emulate -rA (used in summary_passes() and short_test_summary())
    with open(report_files["passes"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_passes()

    with open(report_files["summary_short"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.short_test_summary()

    with open(report_files["stats"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_stats()

    # restore:
    tr._tw = orig_writer
    tr.reportchars = orig_reportchars
    config.option.tbstyle = orig_tbstyle
class CaptureLogger:
    """
    Context manager that captures what a `logging` logger emits.

    Args:
        logger: `logging` logger object

    Returns:
        The captured output is available via `self.out` once the `with` block has exited.

    Example:

    ```python
    >>> from diffusers.utils import logging
    >>> from diffusers.utils.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out == msg + "\\n"
    ```
    """

    def __init__(self, logger):
        buffer = StringIO()
        self.out = ""
        self.logger = logger
        self.io = buffer
        # Handler that writes every record into the in-memory buffer.
        self.sh = logging.StreamHandler(buffer)

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        # Detach first, then snapshot what was written while attached.
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()
        return None

    def __repr__(self):
        return "captured: {}\n".format(self.out)
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape: shape of the result, as a tuple or a list; `shape[0]` is the batch size.
        generator: a single generator, or a list with one generator per batch element. With a list, one
            `(1, *shape[1:])` tensor is sampled per generator and the results are concatenated along dim 0.
        device: target device, as a `torch.device` or a device string; defaults to CPU.
        dtype: dtype forwarded to `torch.randn`.
        layout: layout forwarded to `torch.randn`; defaults to `torch.strided`.

    Returns:
        The sampled tensor, moved to `device`.

    Raises:
        ValueError: if a CUDA generator is passed for a tensor on a different device type.
    """
    # Accept a device given as a string: the generator checks below read `device.type`.
    if isinstance(device, str):
        device = torch.device(device)

    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare device *types*: `device` is a `torch.device` here, and comparing it to the string "mps"
            # does not identify an mps device.
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        # `tuple(...)` so that a list `shape` (allowed by the signature) can be concatenated as well.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
class CCProjection(ModelMixin, ConfigMixin):
    """Linear projection (with bias) from `in_channel` features to `out_channel` features."""

    @register_to_config
    def __init__(self, in_channel=772, out_channel=768):
        super().__init__()
        self.in_channel, self.out_channel = in_channel, out_channel
        self.projection = torch.nn.Linear(in_features=in_channel, out_features=out_channel)

    def forward(self, x):
        projected = self.projection(x)
        return projected


class CLIPProjection(ModelMixin, ConfigMixin):
    """Bias-free linear projection from `in_channel` features to `out_channel` features.

    No normalization is applied to the input before the projection.
    """

    @register_to_config
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel, self.out_channel = in_channel, out_channel
        self.visual_projection = torch.nn.Linear(in_features=in_channel, out_features=out_channel, bias=False)

    def forward(self, x):
        projected = self.visual_projection(x)
        return projected


class CNLayernorm(ModelMixin, ConfigMixin):
    """`torch.nn.LayerNorm` over the last `in_channel` features, with a configurable `eps`."""

    @register_to_config
    def __init__(self, in_channel, eps):
        super().__init__()
        self.in_channel = in_channel
        self.layernorm = torch.nn.LayerNorm(normalized_shape=in_channel, eps=eps)

    def forward(self, x):
        normalized = self.layernorm(x)
        return normalized
Stable Diffusion Image Variation uses the vision portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, image_encoder: CN_encoder, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: AutoImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, image_encoder=image_encoder, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) self.ConvNextV2_preprocess = transforms.Compose([ transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) if isinstance(image, torch.Tensor): # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) assert image.ndim == 4, "Image must have 4 dimensions" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: raise ValueError("Image should be in [-1, 1] range") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 image = image.to(device=device, dtype=dtype) # [-1, 1] -> [0, 1] image = (image + 1.) / 2. 
image = self.ConvNextV2_preprocess(image) image_embeddings = self.image_encoder(image) # bt, 768, 12, 12 # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: negative_prompt_embeds = torch.zeros_like(image_embeddings) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings def run_safety_checker(self, image, device, dtype): if self.safety_checker is not None: safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) else: has_nsfw_concept = None return image, has_nsfw_concept def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, image, height, width, callback_steps): if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" f" {type(image)}" ) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def prepare_img_latents(self, image, batch_size, dtype, device, generator=None, do_classifier_free_guidance=False, t_in=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) if isinstance(image, torch.Tensor): # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) assert image.ndim == 4, "Image must have 4 dimensions" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: raise ValueError("Image should be in [-1, 1] range") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 image = image.to(device=device, dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i]) for i in range(batch_size) # sample ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.mode() if batch_size > init_latents.shape[0]: # init_latents = init_latents.repeat(batch_size // init_latents.shape[0], 1, 1, 1) num_images_per_prompt = batch_size // init_latents.shape[0] # duplicate image latents for each generation per prompt, using mps friendly method bs_embed, emb_c, emb_h, emb_w = init_latents.shape init_latents = init_latents.unsqueeze(1) init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1) init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w) init_latents = torch.cat([torch.zeros_like(init_latents), init_latents]) if do_classifier_free_guidance else init_latents init_latents = init_latents.to(device=device, dtype=dtype) return init_latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, input_imgs: Union[torch.FloatTensor, PIL.Image.Image] = None, prompt_imgs: Union[torch.FloatTensor, PIL.Image.Image] = None, poses: Optional = None, # projections: Union[List] = None, torch_dtype=torch.float32, height: Optional[int] = None, width: Optional[int] = None, T_in: Optional[int] = None, T_out: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 3.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: 
Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: input_imgs (`PIL` or `List[PIL]`, *optional*): The single input image for each 3D object prompt_imgs (`PIL` or `List[PIL]`, *optional*): Same as input_imgs, but will be used later as an image prompt condition, encoded by CLIP feature height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor assert T_out == poses[0].shape[1] # 1. Check inputs. Raise error if not correct # input_image = hint_imgs self.check_inputs(input_imgs, height, width, callback_steps) # 2. Define call parameters if isinstance(input_imgs, PIL.Image.Image): batch_size = 1 elif isinstance(input_imgs, list): batch_size = len(input_imgs) else: batch_size = input_imgs.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input image with pose as prompt # prompt_embeds = self._encode_image_with_pose(prompt_imgs, poses, device, num_images_per_prompt, do_classifier_free_guidance, t_in) prompt_embeds = self._encode_image(prompt_imgs, device, num_images_per_prompt, do_classifier_free_guidance) prompt_embeds = einops.rearrange(prompt_embeds, '(b t) l c -> b (t l) c', t=T_in) if do_classifier_free_guidance: pose_out, pose_in = poses pose_in = torch.cat([pose_in] * 2) pose_out = torch.cat([pose_out] * 2) poses = [pose_out, pose_in] # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables latents = self.prepare_latents( batch_size // T_in * T_out * num_images_per_prompt, 4, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input], dim=1) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, pose=poses).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype) latents = self.scheduler.step(noise_pred, t, latents, 
**extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Post-processing has_nsfw_concept = None if output_type == "latent": image = latents elif output_type == "pil": # 8. Post-processing image = self.decode_latents(latents) # 10. Convert to PIL image = self.numpy_to_pil(image) else: # 8. Post-processing image = self.decode_latents(latents) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 4DoF/train_eschernet.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and import argparse import copy import logging import math import os import shutil from pathlib import Path import einops import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs from dataset import ObjaverseData from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from CN_encoder import CN_encoder import diffusers from diffusers import ( AutoencoderKL, DDIMScheduler, DDPMScheduler, # UNet2DConditionModel, ) from unet_2d_condition import UNet2DConditionModel from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline from diffusers.optimization import get_scheduler from diffusers.utils import is_wandb_available from diffusers.utils.import_utils import is_xformers_available from diffusers.training_utils import EMAModel import torchvision import itertools # metrics import cv2 from skimage.metrics import structural_similarity as calculate_ssim import lpips LPIPS = lpips.LPIPS(net='alex', version='0.1') if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. # check_min_version("0.19.0.dev0") logger = get_logger(__name__) def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols w, h = imgs[0].size grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(imgs): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid @torch.no_grad() def log_validation(validation_dataloader, vae, image_encoder, feature_extractor, unet, args, accelerator, weight_dtype, split="val"): logger.info("Running {} validation... 
".format(split)) scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") pipeline = Zero1to3StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=accelerator.unwrap_model(vae).eval(), image_encoder=accelerator.unwrap_model(image_encoder).eval(), feature_extractor=feature_extractor, unet=accelerator.unwrap_model(unet).eval(), scheduler=scheduler, safety_checker=None, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) if args.enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) image_logs = [] val_lpips = 0 val_ssim = 0 val_psnr = 0 val_loss = 0 val_num = 0 T_out = args.T_out # fix to be 1? for T_in_val in [1, args.T_in_val//2, args.T_in_val]: # eval different number of given views for valid_step, batch in tqdm(enumerate(validation_dataloader)): if args.num_validation_batches is not None and valid_step >= args.num_validation_batches: break T_in = T_in_val gt_image = batch["image_target"].to(dtype=weight_dtype) input_image = batch["image_input"].to(dtype=weight_dtype)[:, :T_in] pose_in = batch["pose_in"].to(dtype=weight_dtype)[:, :T_in] # BxTx4 pose_out = batch["pose_out"].to(dtype=weight_dtype) # BxTx4 gt_image = einops.rearrange(gt_image, 'b t c h w -> (b t) c h w', t=T_out) input_image = einops.rearrange(input_image, 'b t c h w -> (b t) c h w', t=T_in) # T_in images = [] h, w = input_image.shape[2:] for _ in range(args.num_validation_images): with torch.autocast("cuda"): image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in], height=h, width=w, T_in=T_in, T_out=pose_out.shape[1], guidance_scale=args.guidance_scale, num_inference_steps=50, generator=generator, output_type="numpy").images pred_image = torch.from_numpy(image * 2. 
- 1.).permute(0, 3, 1, 2) images.append(pred_image) pred_np = (image * 255).astype(np.uint8) # [0,1] gt_np = (gt_image / 2 + 0.5).clamp(0, 1) gt_np = (gt_np.cpu().permute(0, 2, 3, 1).float().numpy()*255).astype(np.uint8) # for 1 image # pixel loss loss = F.mse_loss(pred_image[0], gt_image[0].cpu()).item() # LPIPS lpips = LPIPS(pred_image[0], gt_image[0].cpu()).item() # [-1, 1] torch tensor # SSIM ssim = calculate_ssim(pred_np[0], gt_np[0], channel_axis=2) # PSNR psnr = cv2.PSNR(gt_np[0], pred_np[0]) val_loss += loss val_lpips += lpips val_ssim += ssim val_psnr += psnr val_num += 1 image_logs.append( {"gt_image": gt_image, "pred_images": images, "input_image": input_image} ) pixel_loss = val_loss / val_num pixel_lpips= val_lpips / val_num pixel_ssim = val_ssim / val_num pixel_psnr = val_psnr / val_num for tracker in accelerator.trackers: if tracker.name == "wandb": # need to use table, wandb doesn't allow more than 108 images assert args.num_validation_images == 2 table = wandb.Table(columns=["Input", "GT", "Pred1", "Pred2"]) for log_id, log in enumerate(image_logs): formatted_images = [[], [], []] # [[input], [gt], [pred]] pred_images = log["pred_images"] # pred input_image = log["input_image"] # input gt_image = log["gt_image"] # GT formatted_images[0].append(wandb.Image(input_image, caption="{}_input".format(log_id))) formatted_images[1].append(wandb.Image(gt_image, caption="{}_gt".format(log_id))) for sample_id, pred_image in enumerate(pred_images): # n_samples pred_image = wandb.Image(pred_image, caption="{}_pred_{}".format(log_id, sample_id)) formatted_images[2].append(pred_image) table.add_data(*formatted_images[0], *formatted_images[1], *formatted_images[2]) tracker.log({split: table, # formatted_images "{}_T{}_pixel_loss".format(split, T_in_val): pixel_loss, "{}_T{}_lpips".format(split, T_in_val): pixel_lpips, "{}_T{}_ssim".format(split, T_in_val): pixel_ssim, "{}_T{}_psnr".format(split, T_in_val): pixel_psnr}) else: logger.warn(f"image logging not 
implemented for {tracker.name}") # del pipeline # torch.cuda.empty_cache() # after validation, set the pipeline back to training mode unet.train() vae.eval() image_encoder.train() return image_logs def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a Zero123 training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default="lambdalabs/sd-image-variations-diffusers", required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help=( "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" " float32 precision." ), ) parser.add_argument( "--output_dir", type=str, default="eschernet-4dof", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, default=256, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--T_in", type=int, default=1, help="Number of input views" ) parser.add_argument( "--T_in_val", type=int, default=10, help="Number of input views" ) parser.add_argument( "--T_out", type=int, default=1, help="Number of output views" ) parser.add_argument( "--max_train_steps", type=int, default=100000, help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", ) parser.add_argument( "--guidance_scale", type=float, default=3.0, help="unconditional guidance scale, if guidance_scale>1.0, do_classifier_free_guidance" ) parser.add_argument( "--conditioning_dropout_prob", type=float, default=0.05, help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800" ) parser.add_argument( "--checkpointing_steps", type=int, default=2000, help=( "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" "instructions." ), ) parser.add_argument( "--checkpoints_total_limit", type=int, default=20, help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument( "--dataloader_num_workers", type=int, default=1, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=0.5, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="wandb", # log_image currently only for wandb help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. 
Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--enable_xformers_memory_efficient_attention", default=True, help="Whether or not to use xformers." ) parser.add_argument( "--set_grads_to_none", default=True, help=( "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" " behaviors, so disable this argument if it causes any problems. More info:" " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) parser.add_argument( "--dataset_name", type=str, default=None, help=( "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," " or to a folder containing files that 🤗 Datasets can understand." ), ) parser.add_argument( "--dataset_config_name", type=str, default=None, help="The config of the Dataset, leave as None if there's only one config.", ) parser.add_argument( "--train_data_dir", type=str, default=None, help=( "A folder containing the training data. Folder contents must follow the structure described in" " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." ), ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument( "--num_validation_images", type=int, default=2, help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", ) parser.add_argument( "--validation_steps", type=int, default=2000, help=( "Run validation every X steps. 
Validation consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`" " and logging the images." ), ) parser.add_argument( "--num_validation_batches", type=int, default=20, help=( "Number of batches to use for validation. If `None`, use all batches." ), ) parser.add_argument( "--tracker_project_name", type=str, default="train_zero123_hf", help=( "The `project_name` argument passed to Accelerator.init_trackers for" " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") if args.dataset_name is not None and args.train_data_dir is not None: raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") if args.resolution % 8 != 0: raise ValueError( "`--resolution` must be divisible by 8 for consistently sized encoded images." ) return args ConvNextV2_preprocess = transforms.Compose([ transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) def _encode_image(feature_extractor, image_encoder, image, device, dtype, do_classifier_free_guidance): # [-1, 1] -> [0, 1] image = (image + 1.) / 2. 
image = ConvNextV2_preprocess(image) image_embeddings = image_encoder(image) # bt, 768, 12, 12 if do_classifier_free_guidance: negative_prompt_embeds = torch.zeros_like(image_embeddings) image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings #.detach() # !we need keep image encoder gradient def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token, private=True ).repo_id # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision) image_encoder = CN_encoder.from_pretrained("facebook/convnextv2-tiny-22k-224") feature_extractor = None vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision) T_in = args.T_in T_in_val = args.T_in_val T_out = args.T_out vae.eval() vae.requires_grad_(False) image_encoder.train() image_encoder.requires_grad_(True) unet.requires_grad_(True) unet.train() # Create EMA for the unet. if args.use_ema: ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() vae.enable_slicing() else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(unet).dtype != torch.float32: raise ValueError( f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW optimizer = optimizer_class( [{"params": unet.parameters(), "lr": args.learning_rate}, {"params": image_encoder.parameters(), "lr": args.learning_rate}], betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon ) # print model info, learnable parameters, non-learnable parameters, total parameters, model size, all in billion def print_model_info(model): print("="*20) # print model class name print("model name: ", type(model).__name__) print("learnable parameters(M): ", sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6) print("non-learnable parameters(M): ", sum(p.numel() for p in model.parameters() if not p.requires_grad) / 1e6) print("total parameters(M): ", sum(p.numel() for p in model.parameters()) / 1e6) print("model size(MB): ", sum(p.numel() * p.element_size() for p in model.parameters()) / 1024 / 1024) print_model_info(unet) print_model_info(vae) print_model_info(image_encoder) # Init Dataset image_transforms = torchvision.transforms.Compose( [ torchvision.transforms.Resize((args.resolution, args.resolution)), # 256, 256 transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ] ) train_dataset = ObjaverseData(root_dir=args.train_data_dir, image_transforms=image_transforms, validation=False, T_in=T_in, T_out=T_out) train_log_dataset = ObjaverseData(root_dir=args.train_data_dir, image_transforms=image_transforms, validation=False, T_in=T_in_val, T_out=T_out, fix_sample=True) validation_dataset = ObjaverseData(root_dir=args.train_data_dir, image_transforms=image_transforms, validation=True, T_in=T_in_val, T_out=T_out, fix_sample=True) # for training train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) # for validation set logs validation_dataloader = torch.utils.data.DataLoader( validation_dataset, shuffle=False, 
batch_size=1, num_workers=1, ) # for training set logs train_log_dataloader = torch.utils.data.DataLoader( train_log_dataset, shuffle=False, batch_size=1, num_workers=1, ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): """Warmup the learning rate""" lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) for param_group in optimizer.param_groups: param_group['lr'] = lr def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): """Decay the learning rate""" lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr for param_group in optimizer.param_groups: param_group['lr'] = lr # Prepare everything with our `accelerator`. unet, image_encoder, optimizer, train_dataloader, validation_dataloader, train_log_dataloader = accelerator.prepare( unet, image_encoder, optimizer, train_dataloader, validation_dataloader, train_log_dataloader ) if args.use_ema: ema_unet.to(accelerator.device) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move vae, image_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) run_name = args.output_dir.split("logs_")[1] accelerator.init_trackers(args.tracker_project_name, config=tracker_config, init_kwargs={"wandb":{"name":run_name}}) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps do_classifier_free_guidance = args.guidance_scale > 1.0 logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" do_classifier_free_guidance = {do_classifier_free_guidance}") logger.info(f" conditioning_dropout_prob = {args.conditioning_dropout_prob}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) for epoch in range(first_epoch, args.num_train_epochs): loss_epoch = 0.0 num_train_elems = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet, image_encoder): gt_image = batch["image_target"].to(dtype=weight_dtype) # BxTx3xHxW gt_image = einops.rearrange(gt_image, 'b t c h w -> (b t) c h w', t=T_out) input_image = batch["image_input"].to(dtype=weight_dtype) # Bx3xHxW input_image = einops.rearrange(input_image, 'b t c h w -> (b t) c h w', t=T_in) pose_in = batch["pose_in"].to(dtype=weight_dtype) # BxTx4 pose_out = batch["pose_out"].to(dtype=weight_dtype) # BxTx4 gt_latents = vae.encode(gt_image).latent_dist.sample().detach() gt_latents = gt_latents * vae.config.scaling_factor # follow zero123, only target image latent is scaled # Sample noise that we'll add to the latents bsz = gt_latents.shape[0] // T_out noise = torch.randn_like(gt_latents) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device) timesteps = timesteps.long() timesteps = einops.repeat(timesteps, 'b -> (b t)', t=T_out) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(gt_latents.to(dtype=torch.float32), noise.to(dtype=torch.float32), timesteps).to(dtype=gt_latents.dtype) if do_classifier_free_guidance: #support classifier-free guidance, randomly drop out 5% # Conditioning dropout to support classifier-free guidance during inference. For more details # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. random_p = torch.rand(bsz, device=gt_latents.device) # Sample masks for the edit prompts. 
prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1, 1) img_prompt_embeds = _encode_image(feature_extractor, image_encoder, input_image, gt_latents.device, gt_latents.dtype, False) # Final text conditioning. img_prompt_embeds = einops.rearrange(img_prompt_embeds, '(b t) l c -> b t l c', t=T_in) null_conditioning = torch.zeros_like(img_prompt_embeds).detach() img_prompt_embeds = torch.where(prompt_mask, null_conditioning, img_prompt_embeds) img_prompt_embeds = einops.rearrange(img_prompt_embeds, 'b t l c -> (b t) l c', t=T_in) prompt_embeds = torch.cat([img_prompt_embeds], dim=-1) else: # Get the image_with_pose embedding for conditioning prompt_embeds = _encode_image(feature_extractor, image_encoder, input_image, gt_latents.device, gt_latents.dtype, False) prompt_embeds = einops.rearrange(prompt_embeds, '(b t) l c -> b (t l) c', t=T_in) # noisy_latents (b T_out) latent_model_input = torch.cat([noisy_latents], dim=1) # Predict the noise residual model_pred = unet( latent_model_input, timesteps, encoder_hidden_states=prompt_embeds, # (bxT_in) l 768 pose=[pose_out, pose_in], # (bxT_in) 4, pose_out - self-attn, pose_in - cross-attn ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(gt_latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = (loss.mean([1, 2, 3])).mean() accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = itertools.chain(unet.parameters(), image_encoder.parameters()) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() # cosine if global_step <= args.lr_warmup_steps: warmup_lr_schedule(optimizer, global_step, 
args.lr_warmup_steps, 1e-5, args.learning_rate) else: cosine_lr_schedule(optimizer, global_step, args.max_train_steps, args.learning_rate, 1e-5) optimizer.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: ema_unet.step(unet.parameters()) progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") # save pipeline # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: pipelines = os.listdir(args.output_dir) pipelines = [d for d in pipelines if d.startswith("pipeline")] pipelines = sorted(pipelines, key=lambda x: int(x.split("-")[1])) # before we save the new pipeline, we need to have at _most_ `checkpoints_total_limit - 1` pipeline if 
len(pipelines) >= args.checkpoints_total_limit: num_to_remove = len(pipelines) - args.checkpoints_total_limit + 1 removing_pipelines = pipelines[0:num_to_remove] logger.info( f"{len(pipelines)} pipelines already exist, removing {len(removing_pipelines)} pipelines" ) logger.info(f"removing pipelines: {', '.join(removing_pipelines)}") for removing_pipeline in removing_pipelines: removing_pipeline = os.path.join(args.output_dir, removing_pipeline) shutil.rmtree(removing_pipeline) if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) pipeline = Zero1to3StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=accelerator.unwrap_model(vae), image_encoder=accelerator.unwrap_model(image_encoder), feature_extractor=feature_extractor, unet=accelerator.unwrap_model(unet), scheduler=noise_scheduler, safety_checker=None, torch_dtype=torch.float32, ) pipeline_save_path = os.path.join(args.output_dir, f"pipeline-{global_step}") pipeline.save_pretrained(pipeline_save_path) # del pipeline if args.push_to_hub: print("Pushing to the hub ", repo_id) upload_folder( repo_id=repo_id, folder_path=pipeline_save_path, commit_message=global_step, ignore_patterns=["step_*", "epoch_*"], run_as_future=True, ) if args.use_ema: # Switch back to the original UNet parameters. ema_unet.restore(unet.parameters()) if validation_dataloader is not None and global_step % args.validation_steps == 0: if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) image_logs = log_validation( validation_dataloader, vae, image_encoder, feature_extractor, unet, args, accelerator, weight_dtype, 'val', ) if args.use_ema: # Switch back to the original UNet parameters. 
ema_unet.restore(unet.parameters()) if train_log_dataloader is not None and (global_step % args.validation_steps == 0 or global_step == 1): if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) train_image_logs = log_validation( train_log_dataloader, vae, image_encoder, feature_extractor, unet, args, accelerator, weight_dtype, 'train', ) if args.use_ema: # Switch back to the original UNet parameters. ema_unet.restore(unet.parameters()) loss_epoch += loss.detach().item() num_train_elems += 1 logs = {"loss": loss.detach().item(), "lr": optimizer.param_groups[0]['lr'], "loss_epoch": loss_epoch / num_train_elems, "epoch": epoch} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters()) pipeline = Zero1to3StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=accelerator.unwrap_model(vae), image_encoder=accelerator.unwrap_model(image_encoder), feature_extractor=feature_extractor, unet=unet, scheduler=noise_scheduler, safety_checker=None, torch_dtype=torch.float32, ) pipeline_save_path = os.path.join(args.output_dir, f"pipeline-{global_step}") pipeline.save_pretrained(pipeline_save_path) if args.push_to_hub: upload_folder( repo_id=repo_id, folder_path=pipeline_save_path, commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) accelerator.end_training() if __name__ == "__main__": # torch.multiprocessing.set_sharing_strategy("file_system") args = parse_args() main(args) ================================================ FILE: 4DoF/unet_2d_condition.py ================================================ # Copyright 2023 The HuggingFace 
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Instances are what the model's `forward` hands back when `return_dict=True`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Declared with a `None` default, so the field may be omitted at construction time.
    sample: torch.FloatTensor = None
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 
class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. 
    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
    ):
        """
        Build every sub-module of the UNet from the configuration arguments.

        In order, this creates: the input convolution, the timestep projection and embedding, the optional
        projection of `encoder_hidden_states`, the optional class and additional embeddings, the down blocks,
        the mid block, the up blocks, and the output normalisation / activation / convolution. The meaning of
        each argument is described in the class docstring.

        Besides the layers, the constructor stores `sample_size`, `num_upsamplers`, `up_block_out_channels`,
        `block_out_channels` and `reversed_block_out_channels` on the instance.

        Raises:
            ValueError: If `num_attention_heads` is passed, if a per-block argument does not have one entry per
                down block, if `down_block_types` and `up_block_types` differ in length, if
                `encoder_hid_dim_type` is set without `encoder_hid_dim`, if a projection-type class embedding
                lacks `projection_class_embeddings_input_dim`, or if `time_embedding_type`,
                `encoder_hid_dim_type`, `addition_embed_type` or `mid_block_type` is not a recognised value.
        """
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        # (Because of the guard above, `num_attention_heads` is always `None` at this point, so this
        # always resolves to `attention_head_dim`.)
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # NOTE(review): unlike the sibling checks, this one only fires for `list`; a tuple of the wrong
        # length is not rejected here — confirm whether that is intended.
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        # Padding of (kernel - 1) // 2 keeps the spatial size unchanged for odd kernel sizes.
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        # Projection applied to `encoder_hidden_states`: giving only `encoder_hid_dim` implies "text_proj",
        # and the inferred type is written back into the registered config.
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        # additional embedding (only created when `addition_embed_type` is set)
        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one entry per down block so they can be indexed with `i` below.
        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # Every down block except the last one gets a downsampler.
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock2DCrossAttn":
            self.mid_block = UNetMidBlock2DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
            )
        elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
            self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                cross_attention_dim=cross_attention_dim[-1],
                attention_head_dim=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        # The up path walks the per-block settings in the opposite order to the down path.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        # Output channel count of each up block, in up-path order.
        self.up_block_out_channels = []
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            # NOTE(review): `attention_head_dim` is indexed in down-path order here while every other
            # per-block setting passed to the up block is reversed — confirm this is intended.
            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
            self.up_block_out_channels.append(output_channel)

        # out
        # Without `norm_num_groups` the final normalisation and activation are skipped entirely.
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )

            self.conv_act = get_activation(act_fn)

        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

        # Keep the channel layouts on the instance for later lookup.
        self.block_out_channels = block_out_channels
        self.reversed_block_out_channels = reversed_block_out_channels
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. 
def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    With slicing enabled, an attention module splits its input into slices and computes attention in several
    steps, trading a little speed for lower memory use.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            `"auto"` halves every layer's sliceable head dim, so attention runs in two steps. `"max"` uses a
            slice size of 1 for every layer. A single number is applied to every layer. A list supplies one
            value per sliceable layer, in module traversal order.

    Raises:
        ValueError: If a list of the wrong length is given, or if any slice size exceeds the sliceable head
            dim of its layer.
    """
    head_dims = []

    def collect_head_dims(node: torch.nn.Module):
        # A module counts as sliceable when it exposes `set_attention_slice`.
        if hasattr(node, "set_attention_slice"):
            head_dims.append(node.sliceable_head_dim)

        for sub in node.children():
            collect_head_dims(sub)

    # First pass: discover every sliceable layer and its head dim.
    for top in self.children():
        collect_head_dims(top)

    layer_count = len(head_dims)

    if slice_size == "auto":
        # Half the attention head size is usually a good trade-off between speed and memory.
        slice_size = [dim // 2 for dim in head_dims]
    elif slice_size == "max":
        # Smallest slice possible.
        slice_size = layer_count * [1]

    # A scalar (or None) is broadcast to every layer.
    if not isinstance(slice_size, list):
        slice_size = layer_count * [slice_size]

    if len(slice_size) != len(head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(head_dims)}."
        )

    for size, dim in zip(slice_size, head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Second pass: hand each sliceable module its value. The sizes are copied in reverse so that
    # `pop()` yields them in traversal order and the caller's list stays untouched.
    remaining = slice_size[::-1]

    def apply_slice_sizes(node: torch.nn.Module, stack: List[int]):
        if hasattr(node, "set_attention_slice"):
            node.set_attention_slice(stack.pop())

        for sub in node.children():
            apply_slice_sizes(sub, stack)

    for top in self.children():
        apply_slice_sizes(top, remaining)
encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. 
class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = 
added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, posemb=pose, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) # default down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, posemb=pose, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, posemb=pose, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) ================================================ FILE: 6DoF/CN_encoder.py ================================================ from transformers import ConvNextV2Model import torch from typing import Optional import einops class CN_encoder(ConvNextV2Model): def __init__(self, config): super().__init__(config) def forward( self, pixel_values: torch.FloatTensor = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") embedding_output = self.embeddings(pixel_values) encoder_outputs = 
class ObjaverseDataLoader():
    """Builds the train/validation loaders over the pre-rendered Objaverse views.

    Args:
        root_dir: Directory with one sub-folder of rendered views per object.
        batch_size: Batch size handed to the loaders.
        total_view: Forwarded to ``ObjaverseData``.
        num_workers: Worker count handed to the loaders.
    """

    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view

        # Resize to 256x256, convert to a tensor, then normalize with mean 0.5 / std 0.5.
        image_transforms = [torchvision.transforms.Resize((256, 256)),
                            transforms.ToTensor(),
                            transforms.Normalize([0.5], [0.5])]
        self.image_transforms = torchvision.transforms.Compose(image_transforms)

    def train_dataloader(self):
        """Return a loader over the training split (first 99% of the object ids)."""
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
                                image_transforms=self.image_transforms)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def val_dataloader(self):
        """Return a loader over the validation split (last 1% of the object ids)."""
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
                                image_transforms=self.image_transforms)
        # No DistributedSampler is created here: the previous one was never passed to the loader
        # and its constructor fails when no distributed process group has been initialised.
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)


def get_pose(transformation):
    """Return the 4x4 transformation unchanged (6DoF poses are used as-is)."""
    # transformation: 4x4
    return transformation


class ObjaverseData(Dataset):
    """Multi-view Objaverse dataset that yields input/target image sets with their poses.

    Each item is a dict with:
      - ``image_input`` / ``image_target``: stacked transformed images for the sampled
        input and target views.
      - ``pose_in`` / ``pose_out``: stacked 4x4 matrices, built from the first three rows of
        each view's stored ``.npy`` plus a ``[0, 0, 0, 1]`` row.
      - ``pose_in_inv`` / ``pose_out_inv``: the inverses of those matrices, transposed.
    """

    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ) -> None:
        """Set up the object id list and the train/validation split.

        Args:
            root_dir: Directory with one sub-folder of rendered views per object.
            image_transforms: Callable applied to every loaded PIL image.
            total_view: Stored on the instance (see the note in ``_select_view_indices``).
            validation: If True use the last 1% of the ids, otherwise the first 99%.
            T_in: Number of input (conditioning) views per item.
            T_out: Number of target views per item.
            fix_sample: If True pick a deterministic set of views instead of sampling.
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample

        # Load ids from .npy so every run sees exactly the same ids in the same order.
        self.paths = np.load("../scripts/obj_ids.npy")

        total_objects = len(self.paths)
        assert total_objects == 790152, 'total objects %d' % total_objects
        split_at = math.floor(total_objects / 100. * 99.)
        if validation:
            self.paths = self.paths[split_at:]  # used last 1% as validation
        else:
            self.paths = self.paths[:split_at]  # used first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms

        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
        self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]],
                                      dtype=torch.float64).view(3, 3)

    def __len__(self):
        return len(self.paths)

    def get_pose(self, transformation):
        """Return the 4x4 transformation unchanged."""
        # transformation: 4x4
        return transformation

    def load_im(self, path, color):
        '''
        Load a rendering and replace its background pixels (alpha == 0) with `color`.

        The failing path is printed and the original exception is re-raised, so callers can
        decide how to recover instead of the whole process being terminated.
        '''
        try:
            img = plt.imread(path)
        except Exception:
            print(path)
            raise
        img[img[:, :, -1] == 0.] = color
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def _select_view_indices(self):
        """Return ``(index_inputs, index_targets)`` for one item."""
        # NOTE(review): the number of rendered views is hard-coded here; `self.total_view`
        # is stored by the constructor but not consulted. Confirm before wiring it in.
        total_view = 12

        if self.fix_sample:
            indexes = range(total_view)
            if self.T_out > 1:
                # First two views plus the last (T_out - 2) views. Slicing with `-0` would
                # select every view, so the tail is only added when it is non-empty.
                tail = self.T_out - 2
                index_targets = list(indexes[:2])
                if tail > 0:
                    index_targets += list(indexes[-tail:])
                index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
            else:
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
        else:
            assert self.T_in + self.T_out <= total_view
            # training with replace, including identity
            indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
            index_inputs = indexes[:self.T_in]
            index_targets = indexes[self.T_in:]
        return index_inputs, index_targets

    def _load_views(self, folder, view_indices, color):
        """Load the image and 4x4 pose of every view in `view_indices` from `folder`."""
        images = []
        poses = []
        for view_index in view_indices:
            image = self.process_im(self.load_im(os.path.join(folder, '%03d.png' % view_index), color))
            images.append(image)
            view_RT = np.load(os.path.join(folder, '%03d.npy' % view_index))
            poses.append(self.get_pose(np.concatenate([view_RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
        return images, poses

    def __getitem__(self, index):
        """Return the dict of stacked images and poses for object `index`.

        If any view of the requested object fails to load, the item is built from a fixed,
        known-valid object instead. A failure while loading that fallback object propagates.
        """
        data = {}
        index_inputs, index_targets = self._select_view_indices()

        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]

        try:
            input_ims, cond_Ts = self._load_views(filename, index_inputs, color)
            target_ims, target_Ts = self._load_views(filename, index_targets, color)
        except Exception:
            print('error loading data ', filename)
            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
            input_ims, cond_Ts = self._load_views(filename, index_inputs, color)
            target_ims, target_Ts = self._load_views(filename, index_targets, color)

        # stack to batch
        data['image_input'] = torch.stack(input_ims, dim=0)
        data['image_target'] = torch.stack(target_ims, dim=0)
        data['pose_out'] = np.stack(target_Ts)
        data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
        data['pose_in'] = np.stack(cond_Ts)
        data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
        return data

    def process_im(self, im):
        """Convert `im` to RGB and apply the configured image transforms."""
        im = im.convert("RGB")
        return self.tform(im)
T5FilmDecoder, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, UNet3DConditionModel, VQModel, ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, get_scheduler, ) from .pipelines import ( AudioPipelineOutput, ConsistencyModelPipeline, DanceDiffusionPipeline, DDIMPipeline, DDPMPipeline, DiffusionPipeline, DiTPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, LDMSuperResolutionPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline, ) from .schedulers import ( CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, DDPMParallelScheduler, DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, PNDMScheduler, RePaintScheduler, SchedulerMixin, ScoreSdeVeScheduler, UnCLIPScheduler, UniPCMultistepScheduler, VQDiffusionScheduler, ) from .training_utils import EMAModel try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .schedulers import LMSDiscreteScheduler try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .schedulers import DPMSolverSDEScheduler try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_objects 
import * # noqa F403 else: from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AudioLDMPipeline, CycleDiffusionPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, ImageTextPipelineOutput, KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline, KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22ControlnetPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, StableDiffusionModelEditingPipeline, StableDiffusionPanoramaPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPipeline, StableDiffusionPipelineSafe, StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, TextToVideoZeroPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, VideoToVideoSDPipeline, VQDiffusionPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): raise 
OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: from .pipelines import StableDiffusionKDiffusionPipeline try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 else: from .pipelines import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, ) try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_librosa_objects import * # noqa F403 else: from .pipelines import AudioDiffusionPipeline, Mel try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 else: from .pipelines import SpectrogramDiffusionPipeline try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_flax_objects import * # noqa F403 else: from .models.controlnet_flax import FlaxControlNetModel from .models.modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel 
from .models.vae_flax import FlaxAutoencoderKL from .pipelines import FlaxDiffusionPipeline from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxDPMSolverMultistepScheduler, FlaxKarrasVeScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, FlaxSchedulerMixin, FlaxScoreSdeVeScheduler, ) try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .pipelines import ( FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, ) try: if not (is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_note_seq_objects import * # noqa F403 else: from .pipelines import MidiProcessor ================================================ FILE: 6DoF/diffusers/commands/__init__.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def main():
    """Entry point of the ``diffusers-cli`` command-line tool.

    Builds the argument parser, registers the available sub-commands, parses the command
    line and runs the selected command. When no sub-command was selected (the parsed
    arguments carry no ``func``), the help text is printed and the process exits with
    status 1.
    """
    # Local import keeps this function self-contained. `sys.exit` is used instead of the
    # `exit` builtin because the latter is injected by the `site` module and is not
    # guaranteed to exist (e.g. when Python runs with `-S`).
    import sys

    parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli []")
    commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")

    # Register commands
    EnvironmentCommand.register_subcommand(commands_parser)

    # Let's go
    args = parser.parse_args()

    if not hasattr(args, "func"):
        parser.print_help()
        sys.exit(1)

    # Run: `func` is the factory installed by the sub-command via `set_defaults`.
    service = args.func(args)
    service.run()
class EnvironmentCommand(BaseDiffusersCLICommand):
    """`diffusers-cli env`: print the versions of the libraries relevant to a bug report."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Add the `env` sub-command to `parser` and bind it to its factory."""
        env_parser = parser.add_parser("env")
        env_parser.set_defaults(func=info_command_factory)

    def run(self):
        """Collect environment information, print it as a bullet list and return it as a dict."""
        hub_version = huggingface_hub.__version__

        # Every optional dependency is imported only once its availability check has passed.
        torch_version, cuda_available = "not installed", "NA"
        if is_torch_available():
            import torch

            torch_version, cuda_available = torch.__version__, torch.cuda.is_available()

        tf_version = "not installed"
        if is_transformers_available():
            import transformers

            tf_version = transformers.__version__

        acc_version = "not installed"
        if is_accelerate_available():
            import accelerate

            acc_version = accelerate.__version__

        xf_version = "not installed"
        if is_xformers_available():
            import xformers

            xf_version = xformers.__version__

        info = {}
        info["`diffusers` version"] = version
        info["Platform"] = platform.platform()
        info["Python version"] = platform.python_version()
        info["PyTorch version (GPU?)"] = f"{torch_version} ({cuda_available})"
        info["Huggingface_hub version"] = hub_version
        info["Transformers version"] = tf_version
        info["Accelerate version"] = acc_version
        info["xFormers version"] = xf_version
        # The last two entries are left blank for the user to fill in.
        info["Using GPU in script?"] = ""
        info["Using distributed or parallel set-up in script?"] = ""

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        """Render `d` as one `- key: value` line per entry, followed by a trailing newline."""
        bullet_lines = []
        for prop, val in d.items():
            bullet_lines.append(f"- {prop}: {val}")
        return "\n".join(bullet_lines) + "\n"
class FrozenDict(OrderedDict):
    """
    Ordered dictionary that holds a configuration and mirrors every entry as an instance attribute.

    `__delitem__`, `setdefault`, `pop` and `update` always raise.

    NOTE(review): inside the class body `self.__frozen` is name-mangled to `_FrozenDict__frozen`, whereas the guards in
    `__setattr__` and `__setitem__` look up the literal name `"__frozen"`. Unless an entry is literally called
    `__frozen`, the guards never fire, so item and attribute assignment keep working after construction. Existing code
    (for example `ConfigMixin.to_json_string`) writes into the dictionary, so the guards are intentionally left as-is.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Expose every entry as an attribute so that `config.key` works next to `config["key"]`.
        for key, value in self.items():
            setattr(self, key, value)

        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)


class ConfigMixin:
    r"""
    Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
    provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
    saving classes that inherit from [`ConfigMixin`].

    Class attributes:
        - **config_name** (`str`) -- A filename under which the config should be stored when calling
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by subclass).
        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
          subclass).
    """
    config_name = None
    ignore_for_config = []
    has_compatibles = False

    _deprecated_kwargs = []

    def register_to_config(self, **kwargs):
        """
        Store `kwargs` in the instance's config, merging them over any config that is already registered.

        Raises:
            `NotImplementedError`: If the class does not define `config_name`.
        """
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        # Special case for `kwargs` used in deprecation warning added to schedulers
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)

        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
            previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")

        self._internal_dict = FrozenDict(internal_dict)

    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

        This function is mostly copied from PyTorch's __getattr__ overwrite:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        # Go through `__dict__` directly: `self._internal_dict` would re-enter `__getattr__` when it is missing.
        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
            return self._internal_dict[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file is saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Accepted for interface compatibility; this implementation does not read it.
            kwargs:
                Accepted for interface compatibility; this implementation does not read them.

        Raises:
            `AssertionError`: If `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary. When `return_unused_kwargs` is
                `True`, a `(model, unused_kwargs)` tuple is returned instead.

        Raises:
            `ValueError`: If no config is provided.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model

    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        """Deprecated alias of [`~ConfigMixin.load_config`]; emits a deprecation notice and forwards all arguments."""
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        return cls.load_config(*args, **kwargs)

    @classmethod
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                      [`~ConfigMixin.save_config`].
                    - A path to a configuration *file*, which is then read directly.

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (*optional*):
                Removed from `kwargs` and otherwise ignored.
            user_agent (`dict`, *optional*):
                Extra entries for the user agent; `"file_type": "config"` is added to them.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False`):
                Whether the `commit_hash` of the loaded configuration is returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file. If `return_unused_kwargs`
                and/or `return_commit_hash` is set, a tuple starting with that dictionary, followed by the remaining
                `kwargs` and/or the commit hash (in that order).

        Raises:
            `ValueError`: If the class does not define `config_name`.
            `EnvironmentError`: If the configuration file cannot be located, downloaded or parsed.
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        _ = kwargs.pop("mirror", None)
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})

        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )

        if os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            elif subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
                    " login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                )

        try:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs

    @staticmethod
    def _get_init_keys(cls):
        """Return the parameter names of `cls.__init__` (including `self`) as a set."""
        return set(dict(inspect.signature(cls.__init__).parameters).keys())

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """
        Split a loaded config into the arguments accepted by `cls.__init__` and everything else.

        Args:
            config_dict (`Dict[str, Any]`):
                The loaded configuration. It is not modified; filtered copies are used internally.
            kwargs:
                Overrides. A keyword whose name matches an `__init__` parameter takes precedence over the config value.

        Returns:
            `tuple`: `(init_dict, unused_kwargs, hidden_config_dict)` where `init_dict` holds the `__init__`
            arguments, `unused_kwargs` the leftover config entries and keywords, and `hidden_config_dict` every entry
            of the (default-filtered) original config that did not end up in `init_dict`.
        """
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """Read `json_file` as UTF-8 text and return the parsed JSON content."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        return self._internal_dict

    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        NumPy arrays are converted to lists and path objects to their string form so that they can be serialized.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format.
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        # NOTE: when a config is registered, the two assignments below are written into it (not into a copy).
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, os.PathLike):
                # Covers every path flavour (not only `PosixPath`), e.g. `WindowsPath` on Windows.
                value = os.fspath(value)
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save the configuration instance's parameters to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file to save a configuration instance's parameters.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())


def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        # `i > 0` skips the first parameter of the signature (`self`).
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        new_kwargs = {**config_init_kwargs, **new_kwargs}
        self.register_to_config(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init


def flax_register_to_config(cls):
    """
    Class decorator that wraps `cls.__init__` so that the dataclass fields of the instance are registered in its
    config before the original `__init__` runs.

    Fields listed in `_flax_internal_args` and the `dtype` entry are not registered. Returns `cls` itself.
    """
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            if field.default is dataclasses.MISSING:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        self.register_to_config(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
# Packages whose versions are always verified at import time
# (usually the ones listed under `install_requires` in setup.py).
#
# Ordering matters:
# - tqdm has to be verified before tokenizers
pkgs_to_check_at_runtime = ["python", "tqdm", "regex", "requests", "packaging", "filelock", "numpy", "tokenizers"]

# Back-port packages are only relevant on interpreters that lack the stdlib module.
if sys.version_info < (3, 7):
    pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
    pkgs_to_check_at_runtime.append("importlib_metadata")

for pkg in pkgs_to_check_at_runtime:
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")

    if pkg == "tokenizers":
        # must be loaded here, or else tqdm check may fail
        from .utils import is_tokenizers_available

        if not is_tokenizers_available():
            # optional dependency: its version is only verified when it is installed
            continue

    require_version_core(deps[pkg])


def dep_version_check(pkg, hint=None):
    """Check the requirement registered for `pkg` in `deps`, passing `hint` through to `require_version`."""
    requirement = deps[pkg]
    require_version(requirement, hint)
run `make deps_table_update`` deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", "compel": "compel==0.1.8", "black": "black~=23.1", "datasets": "datasets", "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", "huggingface-hub": "huggingface-hub>=0.13.2", "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark", "isort": "isort>=5.5.4", "jax": "jax>=0.2.8,!=0.3.2", "jaxlib": "jaxlib>=0.1.65", "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", "torchsde": "torchsde", "note_seq": "note_seq", "librosa": "librosa", "numpy": "numpy", "omegaconf": "omegaconf", "parameterized": "parameterized", "protobuf": "protobuf>=3.20.3,<4", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", "ruff": "ruff>=0.0.241", "safetensors": "safetensors", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "scipy": "scipy", "onnx": "onnx", "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", "torch": "torch>=1.4", "torchvision": "torchvision", "transformers": "transformers>=4.25.1", "urllib3": "urllib3<=2.0.0", } ================================================ FILE: 6DoF/diffusers/experimental/__init__.py ================================================ from .rl import ValueGuidedRLPipeline ================================================ FILE: 6DoF/diffusers/experimental/rl/__init__.py ================================================ from .value_guided_sampling import ValueGuidedRLPipeline ================================================ FILE: 6DoF/diffusers/experimental/rl/value_guided_sampling.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class ValueGuidedRLPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.

    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.

    Parameters:
        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories base on reward.
        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
            application is [`DDPMScheduler`].
        env: An environment following the OpenAI gym API to act in. For now only Hopper has pretrained models.
    """

    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()

        # Per-key statistics used by `normalize` / `de_normalize`. Entries for which `.mean()` / `.std()` cannot be
        # computed are skipped on purpose (best effort), but interpreter-level exceptions such as
        # `KeyboardInterrupt` are no longer swallowed.
        self.means = {}
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
        self.stds = {}
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass

        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        """Standardize `x_in` with the mean and standard deviation stored for `key`."""
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        """Invert [`normalize`] for `key`."""
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        """Move `x_in` to the device of `self.unet`, converting it to a tensor if needed; dicts are handled per value."""
        if type(x_in) is dict:
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        """Overwrite `x_in[:, key, act_dim:]` with the conditioning value of every `key` in `cond`; returns `x_in`."""
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        """
        Run the denoising loop over `self.scheduler.timesteps`, applying `n_guide_steps` value-gradient updates scaled
        by `scale` before every denoising step.

        Returns:
            `tuple`: `(x, y)` with the final sample and the last output of the value function. `y` is `None` when no
            guide step was executed.
        """
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()

                    # permute to match dimension for pre-trained models
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad

                # no guidance for the samples whose timestep is below 2
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)

            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)

            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # apply conditions to the trajectory (set the initial state)
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        """
        Plan a batch of trajectories starting at `obs` and return the first action of the selected one.

        Args:
            obs: Current observation. It is normalized and repeated `batch_size` times along a new leading axis
                (presumably a NumPy array, given the `repeat(..., axis=0)` call; confirm against callers).
            batch_size (`int`, defaults to 64): Number of trajectories sampled in parallel.
            planning_horizon (`int`, defaults to 32): Number of steps per trajectory.
            n_guide_steps (`int`, defaults to 2): Value-gradient updates per denoising step. With `0`, no value
                guidance is run and a random trajectory is selected.
            scale (`float`, defaults to 0.1): Step size of a value-gradient update.

        Returns:
            The de-normalized first action of the selected trajectory.
        """
        # normalize the observations and create batch dimension
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)

        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

        # generate initial noise and apply our conditions (to make the trajectories start at current state)
        x1 = randn_tensor(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)

        # run the diffusion process
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)

        if y is not None:
            # sort output trajectories by value and select the one with the highest value
            sorted_idx = y.argsort(0, descending=True).squeeze()
            trajectories = x[sorted_idx]
            selected_index = 0
        else:
            # if we didn't run value guiding, there is nothing to sort by: select a random action
            trajectories = x
            selected_index = np.random.randint(0, batch_size)

        actions = trajectories[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")

        denorm_actions = denorm_actions[selected_index, 0]
        return denorm_actions
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import List, Optional, Union import numpy as np import PIL import torch from PIL import Image from .configuration_utils import ConfigMixin, register_to_config from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate class VaeImageProcessor(ConfigMixin): """ Image processor for VAE. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. do_convert_rgb (`bool`, *optional*, defaults to be `False`): Whether to convert the images to RGB format. """ config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, do_convert_rgb: bool = False, ): super().__init__() @staticmethod def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: """ Convert a numpy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] 
images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images @staticmethod def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: """ Convert a PIL image or a list of PIL images to NumPy arrays. """ if not isinstance(images, list): images = [images] images = [np.array(image).astype(np.float32) / 255.0 for image in images] images = np.stack(images, axis=0) return images @staticmethod def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: """ Convert a NumPy image to a PyTorch tensor. """ if images.ndim == 3: images = images[..., None] images = torch.from_numpy(images.transpose(0, 3, 1, 2)) return images @staticmethod def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: """ Convert a PyTorch tensor to a NumPy image. """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @staticmethod def normalize(images): """ Normalize an image array to [-1,1]. """ return 2.0 * images - 1.0 @staticmethod def denormalize(images): """ Denormalize an image array to [0,1]. """ return (images / 2 + 0.5).clamp(0, 1) @staticmethod def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: """ Converts an image to RGB format. """ image = image.convert("RGB") return image def resize( self, image: PIL.Image.Image, height: Optional[int] = None, width: Optional[int] = None, ) -> PIL.Image.Image: """ Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`. 
""" if height is None: height = image.height if width is None: width = image.width width, height = ( x - x % self.config.vae_scale_factor for x in (width, height) ) # resize to integer multiple of vae_scale_factor image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) return image def preprocess( self, image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], height: Optional[int] = None, width: Optional[int] = None, ) -> torch.Tensor: """ Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors. """ supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) if isinstance(image, supported_formats): image = [image] elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" ) if isinstance(image[0], PIL.Image.Image): if self.config.do_convert_rgb: image = [self.convert_to_rgb(i) for i in image] if self.config.do_resize: image = [self.resize(i, height, width) for i in image] image = self.pil_to_numpy(image) # to np image = self.numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) image = self.numpy_to_pt(image) _, _, height, width = image.shape if self.config.do_resize and ( height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 ): raise ValueError( f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}" f"currently the sizes are {height} and {width}. 
You can also pass a PIL image instead to use resize option in VAEImageProcessor" ) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) _, channel, height, width = image.shape # don't need any preprocess if the image is latents if channel == 4: return image if self.config.do_resize and ( height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 ): raise ValueError( f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}" f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor" ) # expected range [0,1], normalize to [-1,1] do_normalize = self.config.do_normalize if image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", FutureWarning, ) do_normalize = False if do_normalize: image = self.normalize(image) return image def postprocess( self, image: torch.FloatTensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, ): if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = ( f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " "`pil`, `np`, `pt`, `latent`" ) deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" if output_type == "latent": return image if do_denormalize is None: do_denormalize = [self.config.do_normalize] * image.shape[0] image = torch.stack( [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] ) if output_type == "pt": return image image = self.pt_to_numpy(image) if output_type == "np": return image if output_type == "pil": return self.numpy_to_pil(image) class VaeImageProcessorLDM3D(VaeImageProcessor): """ Image processor for VAE LDM3D. Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. """ config_name = CONFIG_NAME @register_to_config def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, ): super().__init__() @staticmethod def numpy_to_pil(images): """ Convert a NumPy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] 
images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image[:, :, :3]) for image in images] return pil_images @staticmethod def rgblike_to_depthmap(image): """ Args: image: RGB-like depth image Returns: depth map """ return image[:, :, 1] * 2**8 + image[:, :, 2] def numpy_to_depth(self, images): """ Convert a NumPy depth image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images_depth = images[:, :, :, 3:] if images.shape[-1] == 6: images_depth = (images_depth * 255).round().astype("uint8") pil_images = [ Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth ] elif images.shape[-1] == 4: images_depth = (images_depth * 65535.0).astype(np.uint16) pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth] else: raise Exception("Not supported") return pil_images def postprocess( self, image: torch.FloatTensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, ): if not isinstance(image, torch.Tensor): raise ValueError( f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" ) if output_type not in ["latent", "pt", "np", "pil"]: deprecation_message = ( f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " "`pil`, `np`, `pt`, `latent`" ) deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) output_type = "np" if do_denormalize is None: do_denormalize = [self.config.do_normalize] * image.shape[0] image = torch.stack( [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] ) image = self.pt_to_numpy(image) if output_type == "np": if image.shape[-1] == 6: image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0) else: image_depth = image[:, :, :, 3:] return image[:, :, :, :3], image_depth if output_type == "pil": return self.numpy_to_pil(image), self.numpy_to_depth(image) else: raise Exception(f"This type {output_type} is not supported") ================================================ FILE: 6DoF/diffusers/loaders.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import warnings from collections import defaultdict from pathlib import Path from typing import Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from .models.attention_processor import ( AttnAddedKVProcessor, AttnAddedKVProcessor2_0, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, LoRAAttnAddedKVProcessor, LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, SlicedAttnAddedKVProcessor, XFormersAttnProcessor, ) from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, is_transformers_available, logging, ) if is_safetensors_available(): import safetensors if is_transformers_available(): from transformers import PreTrainedModel, PreTrainedTokenizer logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" TEXT_INVERSION_NAME = "learned_embeds.bin" TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" class AttnProcsLayers(torch.nn.Module): def __init__(self, state_dict: Dict[str, torch.Tensor]): super().__init__() self.layers = torch.nn.ModuleList(state_dict.values()) self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} # .processor for unet, .self_attn for text encoder self.split_keys = [".processor", ".self_attn"] # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` def map_to(module, state_dict, *args, **kwargs): new_state_dict = {} for key, value in state_dict.items(): num = int(key.split(".")[1]) # 0 is always "layers" new_key = key.replace(f"layers.{num}", 
module.mapping[num]) new_state_dict[new_key] = value return new_state_dict def remap_key(key, state_dict): for k in self.split_keys: if k in key: return key.split(k)[0] + k raise ValueError( f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." ) def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) for key in all_keys: replace_key = remap_key(key, state_dict) new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") state_dict[new_key] = state_dict[key] del state_dict[key] self._register_state_dict_hook(map_to) self._register_load_state_dict_pre_hook(map_from, with_module=True) class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py) and be a `torch.nn.Module` class. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): Can be either: - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - A path to a directory (for example `./my_model_directory`) containing the model weights saved with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. 
""" cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True user_agent = { "file_type": "attn_procs_weights", "framework": "pytorch", } model_file = None if not isinstance(pretrained_model_name_or_path_or_dict, dict): # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except IOError as e: if not allow_pickle: raise e # try loading non-safetensors weights pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path_or_dict, 
weights_name=weight_name or LORA_WEIGHT_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = torch.load(model_file, map_location="cpu") else: state_dict = pretrained_model_name_or_path_or_dict # fill attn processors attn_processors = {} is_lora = all("lora" in k for k in state_dict.keys()) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: is_new_lora_format = all( key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() ) if is_new_lora_format: # Strip the `"unet"` prefix. is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) if is_text_encoder_present: warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." 
warnings.warn(warn_message) unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in lora_grouped_dict.items(): rank = value_dict["to_k_lora.down.weight"].shape[0] hidden_size = value_dict["to_k_lora.up.weight"].shape[0] attn_processor = self for sub_key in key.split("."): attn_processor = getattr(attn_processor, sub_key) if isinstance( attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) ): cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] attn_processor_class = LoRAAttnAddedKVProcessor else: cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): attn_processor_class = LoRAXFormersAttnProcessor else: attn_processor_class = ( LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor ) attn_processors[key] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank, network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) elif is_custom_diffusion: custom_diffusion_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): if len(value) == 0: custom_diffusion_grouped_dict[key] = {} else: if "to_out" in key: attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) else: attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:]) custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value for key, value_dict in custom_diffusion_grouped_dict.items(): if len(value_dict) == 0: 
attn_processors[key] = CustomDiffusionAttnProcessor( train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None ) else: cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1] hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0] train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False attn_processors[key] = CustomDiffusionAttnProcessor( train_kv=True, train_q_out=train_q_out, hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, ) attn_processors[key].load_state_dict(value_dict) else: raise ValueError( f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training." ) # set correct dtype & device attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()} # set layers self.set_attn_processor(attn_processors) def save_attn_procs( self, save_directory: Union[str, os.PathLike], is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, safe_serialization: bool = False, **kwargs, ): r""" Save an attention processor to a directory so that it can be reloaded using the [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Arguments: save_directory (`str` or `os.PathLike`): Directory to save an attention processor to. Will be created if it doesn't exist. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. Useful during distributed training when you need to replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. 
""" weight_name = weight_name or deprecate( "weights_name", "0.20.0", "`weights_name` is deprecated, please use `weight_name` instead.", take_from=kwargs, ) if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return if save_function is None: if safe_serialization: def save_function(weights, filename): return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) else: save_function = torch.save os.makedirs(save_directory, exist_ok=True) is_custom_diffusion = any( isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) for (_, x) in self.attn_processors.items() ) if is_custom_diffusion: model_to_save = AttnProcsLayers( { y: x for (y, x) in self.attn_processors.items() if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)) } ) state_dict = model_to_save.state_dict() for name, attn in self.attn_processors.items(): if len(attn.state_dict()) == 0: state_dict[name] = {} else: model_to_save = AttnProcsLayers(self.attn_processors) state_dict = model_to_save.state_dict() if weight_name is None: if safe_serialization: weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE else: weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME # Save the model save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") class TextualInversionLoaderMixin: r""" Load textual inversion tokens and embeddings to the tokenizer and text encoder. """ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): r""" Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to be replaced with multiple special tokens each corresponding to one of the vectors. 
If the prompt has no textual inversion token or if the textual inversion token is a single vector, the input prompt is returned. Parameters: prompt (`str` or list of `str`): The prompt or prompts to guide the image generation. tokenizer (`PreTrainedTokenizer`): The tokenizer responsible for encoding the prompt into input tokens. Returns: `str` or list of `str`: The converted prompt """ if not isinstance(prompt, List): prompts = [prompt] else: prompts = prompt prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts] if not isinstance(prompt, List): return prompts[0] return prompts def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): r""" Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds to a multi-vector textual inversion embedding, this function will process the prompt so that the special token is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual inversion token or a textual inversion token that is a single vector, the input prompt is simply returned. Parameters: prompt (`str`): The prompt to guide the image generation. tokenizer (`PreTrainedTokenizer`): The tokenizer responsible for encoding the prompt into input tokens. 
Returns: `str`: The converted prompt """ tokens = tokenizer.tokenize(prompt) unique_tokens = set(tokens) for token in unique_tokens: if token in tokenizer.added_tokens_encoder: replacement = token i = 1 while f"{token}_{i}" in tokenizer.added_tokens_encoder: replacement += f" {token}_{i}" i += 1 prompt = prompt.replace(token, replacement) return prompt def load_textual_inversion( self, pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], token: Optional[Union[str, List[str]]] = None, **kwargs, ): r""" Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and Automatic1111 formats are supported). Parameters: pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`): Can be either one of the following or a list of them: - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a pretrained model hosted on the Hub. - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual inversion weights. - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). token (`str` or `List[str]`, *optional*): Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a list, then `token` must also be a list of equal length. weight_name (`str`, *optional*): Name of a custom weight file. This should be used when: - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight name such as `text_inv.bin`. - The saved textual inversion file is in the Automatic1111 format. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. 
force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. 
Example: To load a textual inversion embedding vector in 🤗 Diffusers format: ```py from diffusers import StableDiffusionPipeline import torch model_id = "runwayml/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") pipe.load_textual_inversion("sd-concepts-library/cat-toy") prompt = "A backpack" image = pipe(prompt, num_inference_steps=50).images[0] image.save("cat-backpack.png") ``` To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector locally: ```py from diffusers import StableDiffusionPipeline import torch model_id = "runwayml/stable-diffusion-v1-5" pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2") prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details." 
image = pipe(prompt, num_inference_steps=50).images[0] image.save("character.png") ``` """ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): raise ValueError( f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling" f" `{self.load_textual_inversion.__name__}`" ) if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel): raise ValueError( f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling" f" `{self.load_textual_inversion.__name__}`" ) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True user_agent = { "file_type": "text_inversion", "framework": "pytorch", } if not isinstance(pretrained_model_name_or_path, list): pretrained_model_name_or_paths = [pretrained_model_name_or_path] else: pretrained_model_name_or_paths = pretrained_model_name_or_path if isinstance(token, str): tokens = [token] elif token is None: tokens = [None] * len(pretrained_model_name_or_paths) else: tokens = token if len(pretrained_model_name_or_paths) != len(tokens): raise ValueError( f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}" f"Make sure both lists have the same length." ) valid_tokens = [t for t in tokens if t is not None] if len(set(valid_tokens)) < len(valid_tokens): raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}") token_ids_and_embeddings = [] for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): if not isinstance(pretrained_model_name_or_path, dict): # 1. 
Load textual inversion file model_file = None # Let's first try to load .safetensors weights if (use_safetensors and weight_name is None) or ( weight_name is not None and weight_name.endswith(".safetensors") ): try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = safetensors.torch.load_file(model_file, device="cpu") except Exception as e: if not allow_pickle: raise e model_file = None if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=weight_name or TEXT_INVERSION_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, ) state_dict = torch.load(model_file, map_location="cpu") else: state_dict = pretrained_model_name_or_path # 2. Load token and embedding correcly from file loaded_token = None if isinstance(state_dict, torch.Tensor): if token is None: raise ValueError( "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." 
) embedding = state_dict elif len(state_dict) == 1: # diffusers loaded_token, embedding = next(iter(state_dict.items())) elif "string_to_param" in state_dict: # A1111 loaded_token = state_dict["name"] embedding = state_dict["string_to_param"]["*"] if token is not None and loaded_token != token: logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") else: token = loaded_token embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) # 3. Make sure we don't mess up the tokenizer or text encoder vocab = self.tokenizer.get_vocab() if token in vocab: raise ValueError( f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." ) elif f"{token}_1" in vocab: multi_vector_tokens = [token] i = 1 while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: multi_vector_tokens.append(f"{token}_{i}") i += 1 raise ValueError( f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." 
class LoraLoaderMixin:
    r"""
    Load LoRA layers into [`UNet2DConditionModel`] and
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

    The mixin expects the host object to expose `unet` and `text_encoder` attributes; both are accessed by the
    loading methods below.
    """
    # Key prefixes that separate text-encoder entries from UNet entries in a serialized LoRA state dict.
    text_encoder_name = TEXT_ENCODER_NAME
    unet_name = UNET_NAME

    @staticmethod
    def _fetch_lora_state_dict(pretrained_model_name_or_path_or_dict, kwargs):
        r"""
        Resolve a model id / path / state dict into an in-memory LoRA state dict.

        Shared by [`load_lora_weights`] and [`_load_text_encoder_attn_procs`] so that both follow exactly the same
        download and fallback rules.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                A dict is returned unchanged; anything else is handed to `_get_model_file`.
            kwargs (`dict`):
                The caller's keyword-argument dict. The download related entries (`cache_dir`, `force_download`,
                `resume_download`, `proxies`, `local_files_only`, `use_auth_token`, `revision`, `subfolder`,
                `weight_name`, `use_safetensors`) are popped from it in place.

        Returns:
            `tuple`: `(state_dict, model_file)`. `model_file` is `None` when a dict was passed in.

        Raises:
            `ValueError`: If `use_safetensors=True` was requested but safetensors is not installed.
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        if use_safetensors and not is_safetensors_available():
            raise ValueError(
                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
            )

        # Falling back to pickled weights is only permitted when the caller did not express a preference.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = is_safetensors_available()
            allow_pickle = True

        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            return pretrained_model_name_or_path_or_dict, None

        download_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "resume_download": resume_download,
            "proxies": proxies,
            "local_files_only": local_files_only,
            "use_auth_token": use_auth_token,
            "revision": revision,
            "subfolder": subfolder,
            "user_agent": {
                "file_type": "attn_procs_weights",
                "framework": "pytorch",
            },
        }

        model_file = None
        state_dict = None

        # Let's first try to load .safetensors weights
        if (use_safetensors and weight_name is None) or (
            weight_name is not None and weight_name.endswith(".safetensors")
        ):
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                    **download_kwargs,
                )
                state_dict = safetensors.torch.load_file(model_file, device="cpu")
            except IOError as e:
                if not allow_pickle:
                    raise e
                # The file may have been located before loading failed; forget it so that the
                # non-safetensors fallback below actually runs instead of leaving `state_dict` unset.
                model_file = None

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name or LORA_WEIGHT_NAME,
                **download_kwargs,
            )
            # NOTE(review): `torch.load` unpickles the file; only use it with checkpoints from trusted sources.
            state_dict = torch.load(model_file, map_location="cpu")

        return state_dict, model_file

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

        Calling this method also resets the LoRA scale of the object to `1.0`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*):
                Name of the weight file to load. When it ends with `.safetensors` the safetensors loader is tried
                first.
            use_safetensors (`bool`, *optional*):
                `True` forces safetensors loading, `None` tries safetensors when the library is installed and falls
                back to the pickled weights on an `IOError`.
        """
        # set lora scale to a reasonable default
        self._lora_scale = 1.0

        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        state_dict, _ = self._fetch_lora_state_dict(pretrained_model_name_or_path_or_dict, kwargs)

        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
        network_alpha = None
        if all(k.startswith(("lora_te_", "lora_unet_")) for k in state_dict.keys()):
            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
            # Load the layers corresponding to UNet.
            logger.info(f"Loading {self.unet_name}.")
            unet_lora_state_dict = {
                k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k.startswith(self.unet_name)
            }
            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_lora_state_dict = {
                k.replace(f"{self.text_encoder_name}.", ""): v
                for k, v in state_dict.items()
                if k.startswith(self.text_encoder_name)
            }
            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {self.text_encoder_name}.")
                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                    text_encoder_lora_state_dict, network_alpha=network_alpha
                )
                self._modify_text_encoder(attn_procs_text_encoder)

                # save lora attn procs of text encoder so that it can be easily retrieved
                self._text_encoder_lora_attn_procs = attn_procs_text_encoder

        # Otherwise, we're dealing with the old format. This means the `state_dict` should only
        # contain the module names of the `unet` as its keys WITHOUT any prefix.
        else:
            self.unet.load_attn_procs(state_dict)
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
            warnings.warn(warn_message)

    @property
    def lora_scale(self) -> float:
        # property function that returns the lora scale which can be set at run time by the pipeline.
        # if _lora_scale has not been set, return 1
        return getattr(self, "_lora_scale", 1.0)

    @property
    def text_encoder_lora_attn_procs(self):
        # The processors stored by the last `load_lora_weights` call that contained text encoder
        # weights, or `None` when no such call happened.
        return getattr(self, "_text_encoder_lora_attn_procs", None)

    def _remove_text_encoder_monkey_patch(self):
        r"""
        Restore the original `forward` of every text encoder projection that was patched by
        [`_modify_text_encoder`]. Projections without an `old_forward` attribute are left untouched.
        """
        # Loop over the CLIPAttention module of text_encoder
        for name, attn_module in self.text_encoder.named_modules():
            if not name.endswith(TEXT_ENCODER_ATTN_MODULE):
                continue
            # Loop over the LoRA layers
            for text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.values():
                # Retrieve the q/k/v/out projection of CLIPAttention
                module = attn_module.get_submodule(text_encoder_attr)
                if hasattr(module, "old_forward"):
                    # restore original `forward` to remove monkey-patch
                    module.forward = module.old_forward
                    delattr(module, "old_forward")

    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.

        Each patched projection computes `old_forward(x) + self.lora_scale * lora_layer(x)`; the scale is read at
        call time, so changing it later affects already patched modules.

        Parameters:
            attn_processors: Dict[str, `LoRAAttnProcessor`]:
                A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
        """
        # First, remove any monkey-patch that might have been applied before
        self._remove_text_encoder_monkey_patch()

        # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
        # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
        def make_new_forward(old_forward, lora_layer):
            def new_forward(x):
                return old_forward(x) + self.lora_scale * lora_layer(x)

            return new_forward

        # Loop over the CLIPAttention module of text_encoder
        for name, attn_module in self.text_encoder.named_modules():
            if not name.endswith(TEXT_ENCODER_ATTN_MODULE):
                continue
            # Loop over the LoRA layers
            for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
                # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
                module = attn_module.get_submodule(text_encoder_attr)
                lora_layer = attn_processors[name].get_submodule(attn_proc_attr)

                # save old_forward to module that can be used to remove monkey-patch
                old_forward = module.old_forward = module.forward

                # Monkey-patch.
                module.forward = make_new_forward(old_forward, lora_layer)

    @property
    def _lora_attn_processor_attr_to_text_encoder_attr(self):
        # Maps the LoRA sub-module names of an attention processor to the projection names of the
        # text encoder attention module they are added to.
        return {
            "to_q_lora": "q_proj",
            "to_k_lora": "k_proj",
            "to_v_lora": "v_proj",
            "to_out_lora": "out_proj",
        }

    def _load_text_encoder_attn_procs(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
    ):
        r"""
        Load pretrained attention processor layers for
        [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

        This function is experimental and might change in the future.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights, e.g., `./my_model_directory/`.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            network_alpha (`float`, *optional*):
                Forwarded unchanged to every created attention processor.
            cache_dir, force_download, resume_download, proxies, local_files_only, use_auth_token, revision,
            subfolder, weight_name, use_safetensors:
                Download and format options with the same meaning as in [`load_lora_weights`].

        Returns:
            `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
            [`LoRAAttnProcessor`], moved to `self.device` and cast to the dtype of `self.text_encoder`.

        Raises:
            `ValueError`: If any key of the state dict does not contain `"lora"`.
        """
        network_alpha = kwargs.pop("network_alpha", None)
        state_dict, model_file = self._fetch_lora_state_dict(pretrained_model_name_or_path_or_dict, kwargs)

        if not all("lora" in k for k in state_dict.keys()):
            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")

        # Group the flat keys per attention module: the last three dotted components name the
        # weight inside the processor (e.g. `to_k_lora.down.weight`), everything before is the module.
        lora_grouped_dict = defaultdict(dict)
        for key, value in state_dict.items():
            parts = key.split(".")
            attn_processor_key, sub_key = ".".join(parts[:-3]), ".".join(parts[-3:])
            lora_grouped_dict[attn_processor_key][sub_key] = value

        attn_processor_class = (
            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        )

        # fill attn processors
        attn_processors = {}
        for key, value_dict in lora_grouped_dict.items():
            # The processor dimensions are read off the stored key-projection weights.
            rank = value_dict["to_k_lora.down.weight"].shape[0]
            cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
            hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

            attn_processors[key] = attn_processor_class(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=rank,
                network_alpha=network_alpha,
            )
            attn_processors[key].load_state_dict(value_dict)

        # set correct dtype & device
        return {k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()}

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = False,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist. If the path points to an
                existing file an error is logged and nothing is saved.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the UNet. A `torch.nn.Module` is also accepted, in
                which case its `state_dict()` is used.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Accepted for API compatibility; this
                implementation does not consult it.
            weight_name (`str`, *optional*):
                File name to save under. Defaults to `LORA_WEIGHT_NAME_SAFE` when `safe_serialization` is set and to
                `LORA_WEIGHT_NAME` otherwise.
            save_function (`Callable`):
                The function to use to save the state dictionary, called as `save_function(state_dict, path)`. Useful
                during distributed training when you need to replace `torch.save` with another method.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save with safetensors instead of `torch.save`. Only affects the default `save_function`
                and the default `weight_name`.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            if safe_serialization:

                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            else:
                save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        # Create a flat dictionary whose keys carry the component prefix expected by `load_lora_weights`.
        state_dict = {}

        if unet_lora_layers is not None:
            weights = (
                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
            )
            state_dict.update({f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()})

        if text_encoder_lora_layers is not None:
            weights = (
                text_encoder_lora_layers.state_dict()
                if isinstance(text_encoder_lora_layers, torch.nn.Module)
                else text_encoder_lora_layers
            )
            state_dict.update(
                {f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()}
            )

        # Save the model
        if weight_name is None:
            weight_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME

        save_path = os.path.join(save_directory, weight_name)
        save_function(state_dict, save_path)
        logger.info(f"Model weights saved in {save_path}")

    def _convert_kohya_lora_to_diffusers(self, state_dict):
        r"""
        Rename the keys of a kohya-ss style LoRA state dict to the prefixed diffusers layout.

        Only `lora_down` entries drive the conversion; the matching `lora_up` weight is looked up by name and stored
        next to it. UNet entries are kept only for `attn1`/`attn2` inside `transformer_blocks`, text encoder entries
        only for `self_attn`.

        Returns:
            `tuple`: `(new_state_dict, network_alpha)` where `network_alpha` is the single `alpha` value found in the
            checkpoint, or `None` if no `alpha` entry exists.

        Raises:
            `ValueError`: If the checkpoint contains differing `alpha` values.
        """
        unet_renames = (
            ("down.blocks", "down_blocks"),
            ("mid.block", "mid_block"),
            ("up.blocks", "up_blocks"),
            ("transformer.blocks", "transformer_blocks"),
            ("to.q.lora", "to_q_lora"),
            ("to.k.lora", "to_k_lora"),
            ("to.v.lora", "to_v_lora"),
            ("to.out.0.lora", "to_out_lora"),
        )
        te_renames = (
            ("text.model", "text_model"),
            ("self.attn", "self_attn"),
            ("q.proj.lora", "to_q_lora"),
            ("k.proj.lora", "to_k_lora"),
            ("v.proj.lora", "to_v_lora"),
            ("out.proj.lora", "to_out_lora"),
        )

        unet_state_dict = {}
        te_state_dict = {}
        network_alpha = None

        for key, value in state_dict.items():
            if "lora_down" not in key:
                continue

            lora_name = key.split(".")[0]
            lora_name_up = lora_name + ".lora_up.weight"
            lora_name_alpha = lora_name + ".alpha"
            if lora_name_alpha in state_dict:
                alpha = state_dict[lora_name_alpha].item()
                if network_alpha is None:
                    network_alpha = alpha
                elif network_alpha != alpha:
                    raise ValueError("Network alpha is not consistent")

            if lora_name.startswith("lora_unet_"):
                # Every underscore becomes a dot first; the renames then restore the
                # identifiers that legitimately contain underscores.
                diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
                for old, new in unet_renames:
                    diffusers_name = diffusers_name.replace(old, new)
                if "transformer_blocks" in diffusers_name:
                    if "attn1" in diffusers_name or "attn2" in diffusers_name:
                        diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                        diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                        unet_state_dict[diffusers_name] = value
                        unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
            elif lora_name.startswith("lora_te_"):
                diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                for old, new in te_renames:
                    diffusers_name = diffusers_name.replace(old, new)
                if "self_attn" in diffusers_name:
                    te_state_dict[diffusers_name] = value
                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
        new_state_dict = {**unet_state_dict, **te_state_dict}
        return new_state_dict, network_alpha
class FromSingleFileMixin:
    """
    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
    """

    @classmethod
    def from_ckpt(cls, *args, **kwargs):
        """Deprecated alias of [`from_single_file`]; emits a deprecation notice and forwards all arguments."""
        deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
        deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
        return cls.from_single_file(*args, **kwargs)

    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
        is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`):
                Can be either:
                    - A link to the `.ckpt` file (for example
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub. A leading
                      `https://huggingface.co/`, `huggingface.co/`, `hf.co/` or `https://hf.co/` is stripped, the
                      first two path components are used as the repo id and the rest (minus a leading `blob/` and
                      `main/`) as the file name inside the repo.
                    - A path to a *file* containing all pipeline weights.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                If given, the loaded pipeline is moved to this dtype before it is returned.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            use_safetensors (`bool`, *optional*):
                Defaults to `None` when safetensors is installed and to `False` otherwise. A `.safetensors` file
                combined with `use_safetensors=False` raises a `ValueError`.
            extract_ema (`bool`, *optional*, defaults to `False`):
                Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
                higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
            upcast_attention (`bool`, *optional*, defaults to `None`):
                Whether the attention computation should always be upcasted.
            image_size (`int`, *optional*):
                The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
                Diffusion v2 base model. Use 768 for Stable Diffusion v2.
            prediction_type (`str`, *optional*):
                The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
                the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
            num_in_channels (`int`, *optional*, defaults to `None`):
                The number of input channels. If `None`, it will be automatically inferred.
            scheduler_type (`str`, *optional*, defaults to `"pndm"`):
                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
                "ddim"]`.
            load_safety_checker (`bool`, *optional*, defaults to `True`):
                Whether to load the safety checker or not.
            text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
                An instance of
                [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use. If
                this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
            tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
                An instance of
                [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
                to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
                itself, if needed.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Any keyword argument not listed above is accepted but not read by this method.

        Raises:
            `ValueError`: If the file is a `.safetensors` file while `use_safetensors` is `False`, or if the calling
            pipeline class is not one of the supported ones.

        Examples:

        ```py
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Download pipeline from local file
        >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
        >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")

        >>> # Enable float16 and move to GPU
        >>> pipeline = StableDiffusionPipeline.from_single_file(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipeline.to("cuda")
        ```
        """
        # import here to avoid circular dependency
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        extract_ema = kwargs.pop("extract_ema", False)
        image_size = kwargs.pop("image_size", None)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
        load_safety_checker = kwargs.pop("load_safety_checker", True)
        prediction_type = kwargs.pop("prediction_type", None)
        text_encoder = kwargs.pop("text_encoder", None)
        tokenizer = kwargs.pop("tokenizer", None)

        torch_dtype = kwargs.pop("torch_dtype", None)

        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        # The string methods used below do not exist on `pathlib.Path`; normalise path-like input
        # up front so that the documented `os.PathLike` form works as well.
        pretrained_model_link_or_path = os.fspath(pretrained_model_link_or_path)

        pipeline_name = cls.__name__
        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
        from_safetensors = file_extension == "safetensors"

        if from_safetensors and use_safetensors is False:
            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

        # TODO: For now we only support stable diffusion
        stable_unclip = None
        model_type = None
        controlnet = False

        if pipeline_name == "StableDiffusionControlNetPipeline":
            # Model type will be inferred from the checkpoint.
            controlnet = True
        elif "StableDiffusion" in pipeline_name:
            # Model type will be inferred from the checkpoint.
            pass
        elif pipeline_name == "StableUnCLIPPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "txt2img"
        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "img2img"
        elif pipeline_name == "PaintByExamplePipeline":
            model_type = "PaintByExample"
        elif pipeline_name == "LDMTextToImagePipeline":
            model_type = "LDMTextToImage"
        else:
            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

        # remove huggingface url
        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
            if pretrained_model_link_or_path.startswith(prefix):
                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
        if not ckpt_path.is_file():
            # get repo_id and (potentially nested) file path of ckpt in repo
            repo_id = "/".join(ckpt_path.parts[:2])
            file_path = "/".join(ckpt_path.parts[2:])

            if file_path.startswith("blob/"):
                file_path = file_path[len("blob/") :]

            if file_path.startswith("main/"):
                file_path = file_path[len("main/") :]

            pretrained_model_link_or_path = hf_hub_download(
                repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                force_download=force_download,
            )

        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,
            pipeline_class=cls,
            model_type=model_type,
            stable_unclip=stable_unclip,
            controlnet=controlnet,
            from_safetensors=from_safetensors,
            extract_ema=extract_ema,
            image_size=image_size,
            scheduler_type=scheduler_type,
            num_in_channels=num_in_channels,
            upcast_attention=upcast_attention,
            load_safety_checker=load_safety_checker,
            prediction_type=prediction_type,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        if torch_dtype is not None:
            pipe.to(torch_dtype=torch_dtype)

        return pipe
================================================ FILE: 6DoF/diffusers/models/__init__.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from ..utils import is_flax_available, is_torch_available if is_torch_available(): from .autoencoder_kl import AutoencoderKL from .controlnet import ControlNetModel from .dual_transformer_2d import DualTransformer2DModel from .modeling_utils import ModelMixin from .prior_transformer import PriorTransformer from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel from .vq_model import VQModel if is_flax_available(): from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL ================================================ FILE: 6DoF/diffusers/models/activations.py ================================================ from torch import nn def get_activation(act_fn): if act_fn in ["swish", "silu"]: return nn.SiLU() elif act_fn == "mish": return nn.Mish() elif act_fn == "gelu": return nn.GELU() else: raise ValueError(f"Unsupported activation function: {act_fn}") ================================================ FILE: 6DoF/diffusers/models/attention.py 
================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional import torch import torch.nn.functional as F from torch import nn from ..utils import maybe_allow_in_graph from .activations import get_activation from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used. double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 
attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", final_dropout: bool = False, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) # Define 3 blocks. Each block has its own normalization layer. # 1. Self-Attn if self.use_ada_layer_norm: self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_zero: self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. self.norm2 = ( AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) ) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): # Sets chunk feed-forward self._chunk_size = chunk_size self._chunk_dim = dim def forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, timestep: Optional[torch.LongTensor] = None, posemb: Optional = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. 
Self-Attention if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) else: norm_hidden_states = self.norm1(hidden_states) cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, posemb=posemb, # todo in self attn, posemb shoule be [pose_in, pose_in]? **cross_attention_kwargs, ) if self.use_ada_layer_norm_zero: attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states # 2. Cross-Attention if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, posemb=posemb, **cross_attention_kwargs, ) hidden_states = attn_output + hidden_states # 3. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: raise ValueError( f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
) num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], dim=self._chunk_dim, ) else: ff_output = self.ff(norm_hidden_states) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = ff_output + hidden_states return hidden_states class FeedForward(nn.Module): r""" A feed-forward layer. Parameters: dim (`int`): The number of channels in the input. dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. """ def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) if activation_fn == "gelu-approximate": act_fn = GELU(dim, inner_dim, approximate="tanh") elif activation_fn == "geglu": act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": act_fn = ApproximateGELU(dim, inner_dim) self.net = nn.ModuleList([]) # project in self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states): for module in self.net: hidden_states = module(hidden_states) return hidden_states class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. """ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): super().__init__() self.proj = nn.Linear(dim_in, dim_out) self.approximate = approximate def gelu(self, gate): if gate.device.type != "mps": return F.gelu(gate, approximate=self.approximate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) hidden_states = self.gelu(hidden_states) return hidden_states class GEGLU(nn.Module): r""" A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def gelu(self, gate): if gate.device.type != "mps": return F.gelu(gate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) return hidden_states * self.gelu(gate) class ApproximateGELU(nn.Module): """ The approximate form of Gaussian Error Linear Unit (GELU) For more details, see section 2: https://arxiv.org/abs/1606.08415 """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out) def forward(self, x): x = self.proj(x) return x * torch.sigmoid(1.702 * x) class AdaLayerNorm(nn.Module): """ Norm layer modified to incorporate timestep embeddings. 
""" def __init__(self, embedding_dim, num_embeddings): super().__init__() self.emb = nn.Embedding(num_embeddings, embedding_dim) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, embedding_dim * 2) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) def forward(self, x, timestep): emb = self.linear(self.silu(self.emb(timestep))) scale, shift = torch.chunk(emb, 2) x = self.norm(x) * (1 + scale) + shift return x class AdaLayerNormZero(nn.Module): """ Norm layer adaptive layer norm zero (adaLN-Zero). """ def __init__(self, embedding_dim, num_embeddings): super().__init__() self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) def forward(self, x, timestep, class_labels, hidden_dtype=None): emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp class AdaGroupNorm(nn.Module): """ GroupNorm layer modified to incorporate timestep embeddings. 
""" def __init__( self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 ): super().__init__() self.num_groups = num_groups self.eps = eps if act_fn is None: self.act = None else: self.act = get_activation(act_fn) self.linear = nn.Linear(embedding_dim, out_dim * 2) def forward(self, x, emb): if self.act: emb = self.act(emb) emb = self.linear(emb) emb = emb[:, :, None, None] scale, shift = emb.chunk(2, dim=1) x = F.group_norm(x, self.num_groups, eps=self.eps) x = x * (1 + scale) + shift return x ================================================ FILE: 6DoF/diffusers/models/attention_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import functools import math import flax.linen as nn import jax import jax.numpy as jnp def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): """Multi-head dot product attention with a limited number of queries.""" num_kv, num_heads, k_features = key.shape[-3:] v_features = value.shape[-1] key_chunk_size = min(key_chunk_size, num_kv) query = query / jnp.sqrt(k_features) @functools.partial(jax.checkpoint, prevent_cse=False) def summarize_chunk(query, key, value): attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision) max_score = jnp.max(attn_weights, axis=-1, keepdims=True) max_score = jax.lax.stop_gradient(max_score) exp_weights = jnp.exp(attn_weights - max_score) exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision) max_score = jnp.einsum("...qhk->...qh", max_score) return (exp_values, exp_weights.sum(axis=-1), max_score) def chunk_scanner(chunk_idx): # julienne key array key_chunk = jax.lax.dynamic_slice( operand=key, start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0], # [...,k,h,d] slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features], # [...,k,h,d] ) # julienne value array value_chunk = jax.lax.dynamic_slice( operand=value, start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0], # [...,v,h,d] slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features], # [...,v,h,d] ) return summarize_chunk(query, key_chunk, value_chunk) chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size)) global_max = jnp.max(chunk_max, axis=0, keepdims=True) max_diffs = jnp.exp(chunk_max - global_max) chunk_values *= jnp.expand_dims(max_diffs, axis=-1) chunk_weights *= max_diffs all_values = chunk_values.sum(axis=0) all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0) return all_values / all_weights def jax_memory_efficient_attention( query, key, value, 
precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 ): r""" Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2 https://github.com/AminRezaei0x443/memory-efficient-attention Args: query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head) key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head) value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head) precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`): numerical precision for computation query_chunk_size (`int`, *optional*, defaults to 1024): chunk size to divide query array value must divide query_length equally without remainder key_chunk_size (`int`, *optional*, defaults to 4096): chunk size to divide key and value array value must divide key_value_length equally without remainder Returns: (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head) """ num_q, num_heads, q_features = query.shape[-3:] def chunk_scanner(chunk_idx, _): # julienne query array query_chunk = jax.lax.dynamic_slice( operand=query, start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0], # [...,q,h,d] slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features], # [...,q,h,d] ) return ( chunk_idx + query_chunk_size, # unused ignore it _query_chunk_attention( query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size ), ) _, res = jax.lax.scan( f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter ) return jnp.concatenate(res, axis=-3) # fuse the chunked result back class FlaxAttention(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 Parameters: query_dim (:obj:`int`): Input hidden states dimension heads (:obj:`int`, *optional*, defaults to 8): Number of heads 
dim_head (:obj:`int`, *optional*, defaults to 64): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ query_dim: int heads: int = 8 dim_head: int = 64 dropout: float = 0.0 use_memory_efficient_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): inner_dim = self.dim_head * self.heads self.scale = self.dim_head**-0.5 # Weights were exported with old names {to_q, to_k, to_v, to_out} self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") self.dropout_layer = nn.Dropout(rate=self.dropout) def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def __call__(self, hidden_states, context=None, deterministic=True): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) value_proj = self.value(context) query_states = self.reshape_heads_to_batch_dim(query_proj) key_states = 
self.reshape_heads_to_batch_dim(key_proj) value_states = self.reshape_heads_to_batch_dim(value_proj) if self.use_memory_efficient_attention: query_states = query_states.transpose(1, 0, 2) key_states = key_states.transpose(1, 0, 2) value_states = value_states.transpose(1, 0, 2) # this if statement create a chunk size for each layer of the unet # the chunk size is equal to the query_length dimension of the deepest layer of the unet flatten_latent_dim = query_states.shape[-3] if flatten_latent_dim % 64 == 0: query_chunk_size = int(flatten_latent_dim / 64) elif flatten_latent_dim % 16 == 0: query_chunk_size = int(flatten_latent_dim / 16) elif flatten_latent_dim % 4 == 0: query_chunk_size = int(flatten_latent_dim / 4) else: query_chunk_size = int(flatten_latent_dim) hidden_states = jax_memory_efficient_attention( query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 ) hidden_states = hidden_states.transpose(1, 0, 2) else: # compute attentions attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) attention_scores = attention_scores * self.scale attention_probs = nn.softmax(attention_scores, axis=2) # attend to values hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) hidden_states = self.proj_attn(hidden_states) return self.dropout_layer(hidden_states, deterministic=deterministic) class FlaxBasicTransformerBlock(nn.Module): r""" A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: https://arxiv.org/abs/1706.03762 Parameters: dim (:obj:`int`): Inner hidden states dimension n_heads (:obj:`int`): Number of heads d_head (:obj:`int`): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate only_cross_attention (`bool`, defaults to `False`): Whether to only apply cross attention. 
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ dim: int n_heads: int d_head: int dropout: float = 0.0 only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 use_memory_efficient_attention: bool = False def setup(self): # self attention (or cross_attention if only_cross_attention is True) self.attn1 = FlaxAttention( self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype ) # cross attention self.attn2 = FlaxAttention( self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype ) self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.dropout) def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states if self.only_cross_attention: hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) else: hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention residual = hidden_states hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) hidden_states = hidden_states + residual # feed forward residual = hidden_states hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual return self.dropout_layer(hidden_states, deterministic=deterministic) class FlaxTransformer2DModel(nn.Module): r""" A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: 
https://arxiv.org/pdf/1506.02025.pdf Parameters: in_channels (:obj:`int`): Input number of channels n_heads (:obj:`int`): Number of heads d_head (:obj:`int`): Hidden states dimension inside each head depth (:obj:`int`, *optional*, defaults to 1): Number of transformers block dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate use_linear_projection (`bool`, defaults to `False`): tbd only_cross_attention (`bool`, defaults to `False`): tbd dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): enable memory efficient attention https://arxiv.org/abs/2112.05682 """ in_channels: int n_heads: int d_head: int depth: int = 1 dropout: float = 0.0 use_linear_projection: bool = False only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 use_memory_efficient_attention: bool = False def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) inner_dim = self.n_heads * self.d_head if self.use_linear_projection: self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) else: self.proj_in = nn.Conv( inner_dim, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) self.transformer_blocks = [ FlaxBasicTransformerBlock( inner_dim, self.n_heads, self.d_head, dropout=self.dropout, only_cross_attention=self.only_cross_attention, dtype=self.dtype, use_memory_efficient_attention=self.use_memory_efficient_attention, ) for _ in range(self.depth) ] if self.use_linear_projection: self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) else: self.proj_out = nn.Conv( inner_dim, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) self.dropout_layer = nn.Dropout(rate=self.dropout) def __call__(self, hidden_states, context, deterministic=True): batch, height, width, channels = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if self.use_linear_projection: hidden_states = 
hidden_states.reshape(batch, height * width, channels) hidden_states = self.proj_in(hidden_states) else: hidden_states = self.proj_in(hidden_states) hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) if self.use_linear_projection: hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape(batch, height, width, channels) else: hidden_states = hidden_states.reshape(batch, height, width, channels) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states + residual return self.dropout_layer(hidden_states, deterministic=deterministic) class FlaxFeedForward(nn.Module): r""" Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's [`FeedForward`] class, with the following simplifications: - The activation function is currently hardcoded to a gated linear unit from: https://arxiv.org/abs/2002.05202 - `dim_out` is equal to `dim`. - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`]. 
class FlaxGEGLU(nn.Module):
    r"""
    Linear projection followed by a GELU-gated linear unit (GEGLU, https://arxiv.org/abs/2002.05202).

    A single dense layer produces `2 * (4 * dim)` features. They are split in two equal halves along axis 2; one half
    is multiplied by the GELU of the other, and dropout is applied to the product.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The projection emits the value half and the gate half in one matmul.
        gated_width = 4 * self.dim
        self.proj = nn.Dense(2 * gated_width, dtype=self.dtype)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        projected = self.proj(hidden_states)
        value_half, gate_half = jnp.split(projected, 2, axis=2)
        gated = value_half * nn.gelu(gate_half)
        return self.dropout_layer(gated, deterministic=deterministic)
@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    The module owns the projection layers (`to_q`, `to_k`, `to_v`, `to_out`, and optionally `add_k_proj` /
    `add_v_proj`) and the optional normalization layers. The attention computation itself is delegated to a
    pluggable *processor* object, which `forward` calls with this module as first argument.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, defaults to `False`):
            If `True`, `get_attention_scores` casts query and key to float before computing the scores.
        upcast_softmax (`bool`, defaults to `False`):
            If `True`, `get_attention_scores` casts the scores to float before the softmax.
        cross_attention_norm (`str`, *optional*):
            `None`, `"layer_norm"` or `"group_norm"`; selects the layer stored in `norm_cross`. Any other value
            raises `ValueError`.
        cross_attention_norm_num_groups (`int`, defaults to 32):
            Number of groups used when `cross_attention_norm == "group_norm"`.
        added_kv_proj_dim (`int`, *optional*):
            If given, `add_k_proj` / `add_v_proj` linear layers mapping this width to `heads * dim_head` are created.
        norm_num_groups (`int`, *optional*):
            If given, a `GroupNorm` over `query_dim` channels is stored in `group_norm`.
        spatial_norm_dim (`int`, *optional*):
            If given, a `SpatialNorm` is stored in `spatial_norm`.
        out_bias (`bool`, defaults to `True`): Whether the output linear layer has a bias.
        scale_qk (`bool`, defaults to `True`):
            If `True`, `scale` is `dim_head ** -0.5`, otherwise `1.0`.
        only_cross_attention (`bool`, defaults to `False`):
            If `True`, `to_k` / `to_v` are set to `None`; requires `added_kv_proj_dim`.
        eps (`float`, defaults to 1e-5): Epsilon of `group_norm`.
        rescale_output_factor (`float`, defaults to 1.0): Stored on the module for use by processors.
        residual_connection (`bool`, defaults to `False`): Stored on the module for use by processors.
        processor (`AttnProcessor`, *optional*):
            Processor to install. When omitted, `AttnProcessor2_0` is used if
            `torch.nn.functional.scaled_dot_product_attention` exists and `scale_qk` is `True`, else `AttnProcessor`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        """
        Build the projection and normalization layers and install the attention processor.

        Raises:
            ValueError: if `only_cross_attention` is `True` while `added_kv_proj_dim` is `None`, or if
                `cross_attention_norm` is not one of `None`, `"layer_norm"`, `"group_norm"`.
        """
        super().__init__()
        # Total width of the q/k/v projections across all heads.
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        # Without `to_k` / `to_v` the only key/value source left is the added projections.
        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            # NOTE: this norm uses a fixed epsilon of 1e-5 rather than the `eps` argument.
            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

        # Index 0 is the output projection, index 1 the dropout; processors call them in that order.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """
        Swap the current processor for its xformers counterpart, or back to the non-xformers one.

        The replacement is chosen from the type of the processor currently installed. LoRA and Custom Diffusion
        processors are rebuilt from the attributes of the current one, have its state dict loaded, and are moved
        to the device of its weights.

        Args:
            use_memory_efficient_attention_xformers: `True` to install an xformers processor, `False` to install
                the regular one.
            attention_op: forwarded as `attention_op` to the xformers processors.

        Raises:
            NotImplementedError: when xformers is requested for an added-KV processor that is also a LoRA or
                Custom Diffusion processor.
            ModuleNotFoundError: when xformers is requested but not installed.
            ValueError: when xformers is requested but CUDA is unavailable.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor,
            (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
        )
        is_added_kv_processor = hasattr(self, "processor") and isinstance(
            self.processor,
            (
                AttnAddedKVProcessor,
                AttnAddedKVProcessor2_0,
                SlicedAttnAddedKVProcessor,
                XFormersAttnAddedKVProcessor,
                LoRAAttnAddedKVProcessor,
            ),
        )

        if use_memory_efficient_attention_xformers:
            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                raise NotImplementedError(
                    f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
                )
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    # Any failure of the probe is re-raised unchanged.
                    raise e

            if is_lora:
                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
                # NOTE(review): no `network_alpha` is passed to the rebuilt processor here -- confirm whether
                # the alpha scaling of the original LoRA layers is meant to carry over.
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            elif is_added_kv_processor:
                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                # throw warning
                logger.info(
                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                attn_processor_class = (
                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
                )
                # NOTE(review): as above, `network_alpha` is not forwarded -- confirm this is intended.
                processor = attn_processor_class(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                # set attention processor
                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
                processor = (
                    AttnProcessor2_0()
                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                    else AttnProcessor()
                )

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """
        Install a sliced processor for `slice_size`, or the default processor when `slice_size` is `None`.

        The added-KV variants are chosen whenever `added_kv_proj_dim` is set.

        Raises:
            ValueError: if `slice_size` is larger than `sliceable_head_dim`.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            # set attention processor
            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """
        Install `processor` as the object `forward` delegates to.

        When a `torch.nn.Module` processor is replaced by a plain object, the old one is first removed from
        `self._modules` (and an info message is logged), so its weights no longer belong to this module.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Delegate to the installed processor, passing this module plus all arguments through unchanged."""
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """
        Merge the heads back into the channel axis.

        Reshapes a 3-D `(batch * heads, seq_len, dim)` tensor into `(batch, seq_len, dim * heads)`.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        """
        Split the channel axis of a 3-D `(batch, seq_len, dim)` tensor into heads.

        The result is `(batch, heads, seq_len, dim // heads)`; when `out_dim == 3` it is further flattened to
        `(batch * heads, seq_len, dim // heads)`.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """
        Return `softmax(attention_mask + scale * query @ key^T)` over the last axis, in `query`'s original dtype.

        `upcast_attention` casts query and key to float before the matmul; `upcast_softmax` casts the scores to
        float before the softmax.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 the additive input of `baddbmm` is ignored, so an uninitialized buffer is enough.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        # Drop references to large intermediates as early as possible.
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        """
        Pad and repeat `attention_mask` for the attention computation.

        Returns `None` unchanged. Otherwise, if the mask's last dimension differs from `target_length`,
        `target_length` zeros are appended to that dimension. The mask is then repeated per head: for
        `out_dim == 3` along dim 0 (only if it has fewer than `batch_size * heads` rows), for `out_dim == 4`
        along a newly inserted dim 1.

        Omitting `batch_size` triggers a deprecation notice and falls back to `batch_size = 1`.
        """
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        """
        Apply `norm_cross` to `encoder_hidden_states`.

        A `LayerNorm` is applied directly. For a `GroupNorm` the last two axes are swapped before and after the
        call so that normalization runs over the hidden dimension. Asserts that `norm_cross` is set and is one of
        those two layer types.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            # `__init__` only ever creates LayerNorm or GroupNorm here.
            assert False

        return encoder_hidden_states
class LoRALinearLayer(nn.Module):
    r"""
    Low-rank linear update: a bias-free `down` projection to `rank` features followed by a bias-free `up` projection.

    `down` is initialised from a normal distribution with standard deviation `1 / rank` and `up` with zeros, so a
    freshly created layer outputs zeros.

    Args:
        in_features (`int`): Width of the input.
        out_features (`int`): Width of the output.
        rank (`int`, defaults to 4): Width of the bottleneck; must not exceed `min(in_features, out_features)`.
        network_alpha (*optional*):
            When not `None`, the output is multiplied by `network_alpha / rank`.

    Raises:
        ValueError: if `rank` is larger than `min(in_features, out_features)`.
    """

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()

        largest_rank = min(in_features, out_features)
        if rank > largest_rank:
            raise ValueError(f"LoRA rank {rank} must be less or equal than {largest_rank}")

        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        # Compute in the dtype of the LoRA weights, then hand the result back in the caller's dtype.
        caller_dtype = hidden_states.dtype
        weight_dtype = self.down.weight.dtype

        update = self.up(self.down(hidden_states.to(weight_dtype)))
        if self.network_alpha is not None:
            update *= self.network_alpha / self.rank

        return update.to(caller_dtype)
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) 
class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Attention processor for the Custom Diffusion method.

    Depending on the flags it owns its own key/value projections and/or its own query and output projections, and
    uses them in place of the ones on the `Attention` module.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to newly train the key and value matrices corresponding to the text features.
        train_q_out (`bool`, defaults to `True`):
            Whether to newly train query matrices corresponding to the latent image features.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether to include the bias parameter in `train_q_out`.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=True,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            # Index 0: output projection, index 1: dropout.
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query_proj = self.to_q_custom_diffusion if self.train_q_out else attn.to_q
        query = query_proj(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if not is_cross_attention:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key_proj, value_proj = self.to_k_custom_diffusion, self.to_v_custom_diffusion
        else:
            key_proj, value_proj = attn.to_k, attn.to_v
        key = key_proj(encoder_hidden_states)
        value = value_proj(encoder_hidden_states)

        if is_cross_attention:
            # Gate that is 0 for the first position along dim 1 and 1 elsewhere: the first position's key/value
            # keep their values but are cut off from the gradient.
            grad_gate = torch.ones_like(key)
            grad_gate[:, :1, :] = grad_gate[:, :1, :] * 0.0
            key = grad_gate * key + (1 - grad_gate) * key.detach()
            value = grad_gate * value + (1 - grad_gate) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        out_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        # linear proj
        attended = out_layers[0](attended)
        # dropout
        attended = out_layers[1](attended)

        return attended
class AttnAddedKVProcessor2_0:
    r"""
    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
    learnable key and value matrices for the text encoder.

    The input is flattened from `(batch, channels, ...)` to `(batch, seq_len, channels)`. Keys and values from
    `add_k_proj` / `add_v_proj` are used alone when `attn.only_cross_attention` is set; otherwise they are placed
    in front of the self-attention keys and values along the sequence axis. The result is reshaped back to the
    input shape and the input is added as a residual.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        def to_heads(tensor):
            # (batch, seq_len, dim) -> (batch, heads, seq_len, head_dim)
            return attn.head_to_batch_dim(tensor, out_dim=4)

        residual = hidden_states
        batch, channels = hidden_states.shape[0], hidden_states.shape[1]
        hidden_states = hidden_states.view(batch, channels, -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        # Without an explicit context, the tokens *before* group norm serve as encoder states.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = to_heads(attn.to_q(hidden_states))

        added_key = attn.add_k_proj(encoder_hidden_states)
        added_value = attn.add_v_proj(encoder_hidden_states)
        key = to_heads(added_key)
        value = to_heads(added_value)

        if not attn.only_cross_attention:
            self_key = attn.to_k(hidden_states)
            self_value = attn.to_v(hidden_states)
            key = torch.cat([key, to_heads(self_key)], dim=2)
            value = torch.cat([value, to_heads(self_value)], dim=2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        attended = attended.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj, then dropout
        projected = attn.to_out[1](attn.to_out[0](attended))

        restored = projected.transpose(-1, -2).reshape(residual.shape)
        return restored + residual
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora( encoder_hidden_states ) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora( encoder_hidden_states ) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: key = attn.to_k(hidden_states) + 
class XFormersAttnAddedKVProcessor:
    r"""
    Attention processor that runs xFormers memory-efficient attention for attention layers that carry the extra
    `add_k_proj` / `add_v_proj` projections applied to the encoder hidden states.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to leave it as `None` so that xFormers chooses the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Run attention over `hidden_states` and return a tensor with the same shape as the input.

        The extra key/value projections of `encoder_hidden_states` are always used; unless
        `attn.only_cross_attention` is set, the layer's own key/value projections of the hidden states are
        concatenated after them along the token axis. The input is added back onto the result before returning.
        """
        skip_connection = hidden_states

        # Collapse every trailing dimension into one token axis: (batch, channel, ...) -> (batch, tokens, channel).
        lead, channels = hidden_states.shape[0], hidden_states.shape[1]
        tokens = hidden_states.view(lead, channels, -1).transpose(1, 2)
        batch_size, sequence_length, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Without encoder states the context is the token tensor as it is *before* group norm.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(tokens))

        added_key = attn.head_to_batch_dim(attn.add_k_proj(context))
        added_value = attn.head_to_batch_dim(attn.add_v_proj(context))

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            own_key = attn.head_to_batch_dim(attn.to_k(tokens))
            own_value = attn.head_to_batch_dim(attn.to_v(tokens))
            key = torch.cat([added_key, own_key], dim=1)
            value = torch.cat([added_value, own_value], dim=1)

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attn.batch_to_head_dim(attended.to(query.dtype))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        projected = attn.to_out[1](projected)

        # Back to the layout of the input, then apply the residual connection.
        restored = projected.transpose(-1, -2).reshape(skip_connection.shape)
        return restored + skip_connection
class XFormersAttnProcessor:
    r"""
    Processor for implementing memory efficient attention using xFormers, optionally turning the per-view 2D
    attention into multi-view attention when a pose embedding (`posemb`) is supplied.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        posemb: Optional = None,
    ):
        """
        Compute attention for `hidden_states` and return a tensor laid out like the input.

        Args:
            attn: The attention layer whose projections, norms and scale are used.
            hidden_states: Query-side features, either 3-D `(batch, tokens, dim)` or 4-D
                `(batch, channel, height, width)`.
            encoder_hidden_states: Key/value-side features; when `None` the layer attends to `hidden_states`
                itself (self-attention).
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask` and used as the
                xFormers attention bias.
            temb: Forwarded to `attn.spatial_norm` when that module is present.
            posemb: Optional nested pair `[[p_out, p_out_inv], [p_in, p_in_inv]]`. Each entry is indexed as
                `(b, t, f, g)` by the einops patterns below; axis 1 is presumably the number of target /
                input views -- TODO confirm against the caller.

        Returns:
            The attended features, reshaped back to the input layout, with the residual connection and output
            rescaling applied as configured on `attn`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Remember the full 4-D shape now: `batch_size` is rebound below and, on the multi-view path,
            # no longer equals the leading dimension of the tensor that has to be restored at the end.
            spatial_shape = hidden_states.shape
            batch_size, channel, height, width = spatial_shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if posemb is not None:
            # turn 2d attention into multiview attention
            self_attn = encoder_hidden_states is None  # check if self attn or cross attn
            [p_out, p_out_inv], [p_in, _p_in_inv] = posemb
            t_out, t_in = p_out.shape[1], p_in.shape[1]  # t size
            # Fold the view axis into the token axis so every view attends jointly.
            hidden_states = einops.rearrange(hidden_states, '(b t_out) l d -> b (t_out l) d', t_out=t_out)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # apply 6DoF, todo now only for xformer processor
        if posemb is not None:
            # Repeat each per-view pose once per token of that view so it lines up with the token axis.
            p_out_inv = einops.repeat(
                p_out_inv, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out
            )  # query shape
            if self_attn:
                # Self-attention: keys come from the same (target) views as the queries.
                p_in = einops.repeat(
                    p_out, 'b t_out f g -> b (t_out l) f g', l=query.shape[1] // t_out
                )  # query shape
            else:
                p_in = einops.repeat(p_in, 'b t_in f g -> b (t_in l) f g', l=key.shape[1] // t_in)  # key shape
            query = cape_embed(query, p_out_inv)  # query f_q @ (p_out)^(-T) .permute(0, 1, 3, 2)
            key = cape_embed(key, p_in)  # key f_k @ p_in

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        # self-ttn (bm) l c x (bm) l c -> (bm) l c
        # cross-ttn (bm) l c x b (nl) c -> (bm) l c
        # reuse 2d attention for multiview attention
        # self-ttn b (ml) c x b (ml) c -> b (ml) c
        # cross-ttn b (ml) c x b (nl) c -> b (ml) c
        hidden_states = xformers.ops.memory_efficient_attention(  # query: (bm) l c -> b (ml) c; key: b (nl) c
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if posemb is not None:
            # reshape back
            hidden_states = einops.rearrange(hidden_states, 'b (t_out l) d -> (b t_out) l d', t_out=t_out)

        if input_ndim == 4:
            # Restore the original (batch, channel, height, width) layout recorded before any folding.
            hidden_states = hidden_states.transpose(-1, -2).reshape(spatial_shape)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class LoRAXFormersAttnProcessor(nn.Module):
    r"""
    LoRA attention processor backed by xFormers memory-efficient attention. A low-rank update is added to the
    query, key, value and output projections of the attention layer it is called with.

    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        rank (`int`, defaults to 4):
            The dimension of the LoRA update matrices.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to leave it as `None` so that xFormers chooses the best
            operator.
        network_alpha (`int`, *optional*):
            Equivalent to `alpha`, but its usage is specific to Kohya (A1111) style LoRAs.
    """

    def __init__(
        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        # Key/value updates read the encoder states, so they take the cross-attention width when one is given.
        kv_in_features = cross_attention_dim or hidden_size

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
        self.to_k_lora = LoRALinearLayer(kv_in_features, hidden_size, rank, network_alpha)
        self.to_v_lora = LoRALinearLayer(kv_in_features, hidden_size, rank, network_alpha)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        """
        Run LoRA-augmented attention. `scale` multiplies every LoRA update before it is added to the output of
        the corresponding base projection; the result is laid out like the input `hidden_states`.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_image_like = hidden_states.ndim == 4
        if is_image_like:
            # (batch, channel, height, width) -> (batch, height * width, channel)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized after the key side: the encoder states when present, the hidden states otherwise.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context) + scale * self.to_k_lora(context)
        value = attn.to_v(context) + scale * self.to_v_lora(context)
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attn.batch_to_head_dim(attended)

        # Output projection (base + LoRA update) followed by dropout.
        projected = attn.to_out[0](attended) + scale * self.to_out_lora(attended)
        projected = attn.to_out[1](projected)

        if is_image_like:
            projected = projected.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            projected = projected + shortcut

        return projected / attn.rescale_output_factor
""" def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) inner_dim = hidden_states.shape[-1] if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) value = 
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Custom Diffusion attention processor that computes attention with xFormers memory-efficient attention.

    Args:
        train_kv (`bool`, defaults to `True`):
            Whether to create and use dedicated key and value projections for the text features instead of the
            attention layer's own ones.
        train_q_out (`bool`, defaults to `False`):
            Whether to create and use dedicated query and output projections for the latent image features
            instead of the attention layer's own ones.
        hidden_size (`int`, *optional*, defaults to `None`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The number of channels in the `encoder_hidden_states`.
        out_bias (`bool`, defaults to `True`):
            Whether the dedicated output projection created for `train_q_out` has a bias.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability of the dedicated output block.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to leave it as `None` so that xFormers chooses the best
            operator.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=False,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()

        self.train_kv = train_kv
        self.train_q_out = train_q_out
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        # The `_custom_diffusion` suffix marks these weights for easy serialization and loading.
        if self.train_kv:
            kv_in_features = cross_attention_dim or hidden_size
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)

        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Compute attention for `hidden_states`, routing the query/output and key/value projections through
        this processor's own layers when `train_q_out` / `train_kv` are enabled and through `attn` otherwise.
        """
        is_cross_attention = encoder_hidden_states is not None

        shape_source = encoder_hidden_states if is_cross_attention else hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query_proj = self.to_q_custom_diffusion if self.train_q_out else attn.to_q
        query = query_proj(hidden_states)

        if not is_cross_attention:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        if self.train_kv:
            key = self.to_k_custom_diffusion(context)
            value = self.to_v_custom_diffusion(context)
        else:
            key = attn.to_k(context)
            value = attn.to_v(context)

        if is_cross_attention:
            # Mask that is 0 for the first context token and 1 elsewhere: gradients are cut for that first
            # token only, while the forward values of key and value are left unchanged.
            grad_mask = torch.ones_like(key)
            grad_mask[:, :1, :] = grad_mask[:, :1, :] * 0.0
            key = grad_mask * key + (1 - grad_mask) * key.detach()
            value = grad_mask * value + (1 - grad_mask) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attended = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        attended = attn.batch_to_head_dim(attended.to(query.dtype))

        # Index 0 is the linear projection, index 1 the dropout.
        out_block = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        projected = out_block[0](attended)
        projected = out_block[1](projected)

        return projected
""" def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) batch_size_attention, query_tokens, _ = query.shape hidden_states = torch.zeros( (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) for i in range(batch_size_attention // self.slice_size): start_idx = i * self.slice_size end_idx = (i + 1) * self.slice_size query_slice = query[start_idx:end_idx] key_slice = key[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = 
class SlicedAttnAddedKVProcessor:
    r"""
    Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.

    Attention is computed in chunks along the leading dimension of the head-batched query (the tensor returned by
    `attn.head_to_batch_dim`), which bounds the size of the attention-score tensor that exists at any one time.

    Args:
        slice_size (`int`, *optional*):
            The number of entries of that leading dimension processed per step. It does not have to divide the
            dimension evenly; the last chunk is simply shorter.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        """
        Run sliced attention over `hidden_states` and return a tensor with the same shape as the input.

        Args:
            attn: The attention layer providing projections, norms and score computation.
            hidden_states: Input features; every dimension after the second is flattened into a token axis.
            encoder_hidden_states: Features fed to `add_k_proj` / `add_v_proj`; defaults to the flattened
                hidden states (taken before group norm) when `None`.
            attention_mask: Optional mask, prepared with `attn.prepare_attention_mask` and sliced alongside
                the query.
            temb: Forwarded to `attn.spatial_norm` when that module is present.

        Returns:
            The attended features reshaped to the input shape, with the input added back as a residual.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # (batch, channel, ...) -> (batch, tokens, channel)
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            # The added (encoder) keys/values come first, followed by the layer's own.
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Step through the whole leading dimension. Counting only `batch_size_attention // slice_size` full
        # slices would leave the trailing remainder (or everything, when slice_size is larger than the
        # dimension) as zeros. Tensor slicing clamps `end_idx`, so the final chunk may be shorter.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
forward(self, f, zq): f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f ================================================ FILE: 6DoF/diffusers/models/autoencoder_kl.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, apply_forward_hook from .attention_processor import AttentionProcessor, AttnProcessor from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder @dataclass class AutoencoderKLOutput(BaseOutput): """ Output of AutoencoderKL encoding method. Args: latent_dist (`DiagonalGaussianDistribution`): Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. `DiagonalGaussianDistribution` allows for sampling latents from the distribution. """ latent_dist: "DiagonalGaussianDistribution" class AutoencoderKL(ModelMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This model inherits from [`ModelMixin`]. 
Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, ): super().__init__() # pass init params to Encoder self.encoder = Encoder( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, ) # pass init params to Decoder self.decoder = Decoder( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, ) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) self.use_slicing = False self.use_tiling = False # only relevant if vae tiling is enabled self.tile_sample_min_size = self.config.sample_size sample_size = ( self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size ) self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, Decoder)): module.gradient_checkpointing = value def enable_tiling(self, use_tiling: bool = True): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. 
""" self.use_tiling = use_tiling def disable_tiling(self): r""" Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing decoding in one step. """ self.enable_tiling(False) def enable_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True def disable_slicing(self): r""" Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.use_slicing = False @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. 
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

                Note: when a dict is passed, every entry is removed from it (`dict.pop`) as it is applied, so the
                caller's dict is consumed by this call.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Keys follow the "<module path>.processor" scheme produced by `attn_processors`.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        """
        Encode a batch of images into a latent posterior distribution.

        Dispatches to `tiled_encode` when tiling is enabled and either of the last two dimensions of `x` exceeds
        `tile_sample_min_size`. Otherwise the encoder runs on the whole batch, or one sample at a time when slicing
        is enabled and the batch holds more than one sample.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an `AutoencoderKLOutput` instead of a plain tuple.

        Returns:
            `AutoencoderKLOutput` or `tuple`: The `DiagonalGaussianDistribution` built from the encoder moments,
            either as `latent_dist` of an `AutoencoderKLOutput` or as the only element of a tuple.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            # Encode one sample at a time and re-assemble along the batch dimension.
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode latents without batch slicing.

        Dispatches to `tiled_decode` when tiling is enabled and either of the last two dimensions of `z` exceeds
        `tile_latent_min_size`; otherwise applies `post_quant_conv` followed by the decoder.

        Returns:
            `DecoderOutput` when `return_dict` is `True`, otherwise a one-element tuple holding the decoded tensor.
        """
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of latents, one sample at a time when slicing is enabled and the batch has more than one
        sample.

        `_decode` is always called with its default `return_dict=True` here (its `.sample` is read); the
        `return_dict` argument of this method only controls how the final result is wrapped.

        Returns:
            `DecoderOutput` when `return_dict` is `True`, otherwise a one-element tuple holding the decoded tensor.
        """
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a, b, blend_extent):
        """
        Cross-fade the last rows (dim 2) of `a` into the first rows of `b`.

        `blend_extent` is clamped to the size of dim 2 of both tensors. Row `y` of the blended zone becomes
        `a[-blend_extent + y] * (1 - y / blend_extent) + b[y] * (y / blend_extent)`. `b` is modified in place and
        returned.
        """
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        """
        Cross-fade the last columns (dim 3) of `a` into the first columns of `b`.

        Same linear ramp as `blend_v`, applied along dim 3. `b` is modified in place and returned.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile is encoded independently of the others. To avoid tiling
        artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized
        changes in the output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        # Stride between tile origins in sample space, blend width and kept size in latent space.
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into overlapping tiles of edge `tile_sample_min_size` and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        # Stride between tile origins in latent space, blend width and kept size in sample space.
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of edge `tile_latent_min_size` and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Encode `sample`, pick a latent from the posterior and decode it again.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior. When `False` the posterior mode is used instead.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Passed to `posterior.sample` when `sample_posterior` is `True`.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)


================================================
FILE: 6DoF/diffusers/models/controlnet.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
    CrossAttnDownBlock2D,
    DownBlock2D,
    UNetMidBlock2DCrossAttn,
    get_down_block,
)
from .unet_2d_condition import UNet2DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class ControlNetConditioningEmbedding(nn.Module):
    """
    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
    model) to encode image-space conditions ... into feature maps ..."

    This implementation differs from the quoted description: it uses 3 × 3 convolutions with SiLU activations, the
    default channels are `(16, 32, 96, 256)`, and the output convolution is zero-initialized.

    Args:
        conditioning_embedding_channels (`int`): Number of output channels of the final convolution.
        conditioning_channels (`int`, defaults to 3): Number of channels of the conditioning input.
        block_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
            Channel widths of the intermediate stages. Each consecutive pair adds one stride-1 convolution followed
            by one stride-2 convolution.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]
            # Same-width conv, then a stride-2 conv that widens the channels and downsamples.
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        # Zero-initialized so the embedding contributes nothing until it has been trained.
        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        """Apply `conv_in`, every block (each followed by SiLU) and finally `conv_out` to `conditioning`."""
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


class ControlNetModel(ModelMixin, ConfigMixin):
    """
    A ControlNet model.

    Args:
        in_channels (`int`, defaults to 4):
            The number of channels in the input sample.
        conditioning_channels (`int`, defaults to 3):
            The number of channels of the conditioning input; forwarded to [`ControlNetConditioningEmbedding`].
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
            The frequency shift to apply to the time embedding.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
        block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, defaults to 2):
            The number of layers per block.
        downsample_padding (`int`, defaults to 1):
            The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, defaults to 1):
            The scale factor to use for the mid block.
        act_fn (`str`, defaults to "silu"):
            The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the normalization. If None, normalization and activation layers are
            skipped in post-processing.
        norm_eps (`float`, defaults to 1e-5):
            The epsilon to use for the normalization.
        cross_attention_dim (`int`, defaults to 1280):
            The dimension of the cross attention features.
        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
            The dimension of the attention heads.
        use_linear_projection (`bool`, defaults to `False`):
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
            `"timestep"`, `"identity"`, or `"projection"`; any other value (including `"simple_projection"`) leaves
            the class embedding disabled in this model.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        upcast_attention (`bool`, defaults to `False`):
        resnet_time_scale_shift (`str`, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. global_pool_conditions (`bool`, defaults to `False`): """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, ): super().__init__() # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 
# The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = nn.Conv2d(output_channel, 
output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) # mid mid_block_channel = block_out_channels[-1] controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) @classmethod def from_unet( cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, ): r""" Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. Parameters: unet (`UNet2DConditionModel`): The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied where applicable. 
""" controlnet = cls( in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, down_block_types=unet.config.down_block_types, only_cross_attention=unet.config.only_cross_attention, block_out_channels=unet.config.block_out_channels, layers_per_block=unet.config.layers_per_block, downsample_padding=unet.config.downsample_padding, mid_block_scale_factor=unet.config.mid_block_scale_factor, act_fn=unet.config.act_fn, norm_num_groups=unet.config.norm_num_groups, norm_eps=unet.config.norm_eps, cross_attention_dim=unet.config.cross_attention_dim, attention_head_dim=unet.config.attention_head_dim, num_attention_heads=unet.config.num_attention_heads, use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, upcast_attention=unet.config.upcast_attention, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, conditioning_embedding_out_channels=conditioning_embedding_out_channels, ) if load_weights_from_unet: controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) if controlnet.class_embedding: controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) return controlnet @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight 
name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Any module exposing `set_processor` is treated as an attention layer.
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

                Note: when a dict is passed, every entry is removed from it (`dict.pop`) as it is applied, so the
                caller's dict is consumed by this call.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Keys follow the "<module path>.processor" scheme produced by `attn_processors`.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Raises:
            ValueError: If the number of slice sizes does not match the number of sliceable layers, or if a slice
                size is larger than the corresponding layer's `sliceable_head_dim`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        # A single non-list value is repeated for every sliceable layer.
        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # Reversed so that `pop()` hands the sizes out in the same order the dims were collected above.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `module` when it is one of the down block types."""
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        """
        The [`ControlNetModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor.
            timestep (`Union[torch.Tensor, float, int]`):
                The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states.
            controlnet_cond (`torch.FloatTensor`):
                The conditional input tensor. It is flipped along dim 1 for `"bgr"` channel order and passed through
                the 2D-convolutional conditioning embedding, so it is presumably image-shaped, i.e. `(batch_size,
                conditioning_channels, height, width)` (TODO confirm against callers).
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
                Passed as the second argument to `time_embedding` together with the projected timesteps.
            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
                If given, it is converted to an additive bias `(1 - mask) * -10000.0`, unsqueezed at dim 1 and
                forwarded to the cross-attention down blocks and to the mid block.
            cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
            guess_mode (`bool`, defaults to `False`):
                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:
                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise the tuple
                `(down_block_res_samples, mid_block_res_sample)` is returned.

        Raises:
            ValueError: If the configured channel order is neither `"rgb"` nor `"bgr"`, or if the model has a class
                embedding and `class_labels` is `None`.
        """
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        # The tuple starts with the pre-processed input and grows by the residuals of every down block.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. Control net blocks

        controlnet_down_block_res_samples = ()

        # Each residual is passed through its own 1x1 `controlnet_down_blocks` convolution.
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            # One log-spaced scale per down residual plus a final one (1.0) for the mid block.
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0

            scales = scales * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            # Average over dims 2 and 3, keeping them as size-1 dimensions.
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )


def zero_module(module):
    """Set every parameter of `module` to zero in place and return the same module."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


================================================
FILE: 6DoF/diffusers/models/controlnet_flax.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    """
    The output of [`FlaxControlNetModel`].

    Args:
        down_block_res_samples (`jnp.ndarray`):
            Residuals computed from the down-block activations, one per entry of the model's
            `controlnet_down_blocks`, already multiplied by `conditioning_scale`.
            [`FlaxControlNetModel`] fills this field with a Python list of arrays.
        mid_block_res_sample (`jnp.ndarray`):
            Residual computed from the mid-block activation, already multiplied by `conditioning_scale`.
    """

    down_block_res_samples: jnp.ndarray
    mid_block_res_sample: jnp.ndarray


class FlaxControlNetConditioningEmbedding(nn.Module):
    """
    Convolutional encoder that turns the conditioning input into a feature map that is added to the
    embedded sample inside [`FlaxControlNetModel`].

    Attributes:
        conditioning_embedding_channels: number of output features of the final convolution.
        block_out_channels: feature sizes of the intermediate convolutions. Every adjacent pair adds one
            stride-1 and one stride-2 convolution, so the input is spatially downsampled
            `len(block_out_channels) - 1` times by a factor of two.
        dtype: dtype passed to every convolution.
    """

    conditioning_embedding_channels: int
    block_out_channels: Tuple[int] = (16, 32, 96, 256)
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_in = nn.Conv(
            self.block_out_channels[0],
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        blocks = []
        for i in range(len(self.block_out_channels) - 1):
            channel_in = self.block_out_channels[i]
            channel_out = self.block_out_channels[i + 1]
            # keep resolution and width ...
            conv1 = nn.Conv(
                channel_in,
                kernel_size=(3, 3),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv1)
            # ... then halve the resolution while moving to the next width
            conv2 = nn.Conv(
                channel_out,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )
            blocks.append(conv2)
        self.blocks = blocks

        # zero-initialised so the embedding contributes nothing before training
        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,
            kernel_size=(3, 3),
            padding=((1, 1), (1, 1)),
            kernel_init=nn.initializers.zeros_init(),
            bias_init=nn.initializers.zeros_init(),
            dtype=self.dtype,
        )

    def __call__(self, conditioning):
        """Encode `conditioning` (channels-last) with SiLU after every convolution except the last."""
        embedding = self.conv_in(conditioning)
        embedding = nn.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = nn.silu(embedding)

        embedding = self.conv_out(embedding)

        return embedding


@flax_register_to_config
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A ControlNet model.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*, defaults to 32):
            The size of the input sample.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. The entry `"CrossAttnDownBlock2D"` builds a
            `FlaxCrossAttnDownBlock2D`; any other value builds a `FlaxDownBlock2D`.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Forwarded per block to `FlaxCrossAttnDownBlock2D`; a single bool is repeated for every block.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads. Falls back to `attention_head_dim` when not set.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded to the cross-attention down blocks and the mid block.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
            The tuple of output channel for each block in the `conditioning_embedding` layer.
    """
    sample_size: int = 32
    in_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    controlnet_conditioning_channel_order: str = "rgb"
    conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

    # NOTE: the annotation is a string on purpose. `jax.random.KeyArray` no longer exists in recent JAX
    # releases, and a non-string annotation is evaluated when the class body runs, which made the whole
    # module fail to import there. A string annotation is never evaluated at definition time.
    def init_weights(self, rng: "jax.Array") -> FrozenDict:
        """Initialise the module with dummy inputs and return its `"params"` collection.

        Args:
            rng: JAX PRNG key; it is split into one key for parameters and one for dropout.
        """
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        # the dummy conditioning input has 3 channels and 8x the spatial size of the sample
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

    def setup(self):
        """Build the input/time embeddings, down blocks, mid block and their zero-initialised 1x1 convs."""
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        def zero_conv(features):
            # 1x1 convolution whose kernel and bias both start at zero
            return nn.Conv(
                features,
                kernel_size=(1, 1),
                padding="VALID",
                kernel_init=nn.initializers.zeros_init(),
                bias_init=nn.initializers.zeros_init(),
                dtype=self.dtype,
            )

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        self.controlnet_cond_embedding = FlaxControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=self.conditioning_embedding_out_channels,
        )

        # broadcast scalar settings to one value per down block
        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # down
        down_blocks = []
        controlnet_down_blocks = []

        output_channel = block_out_channels[0]

        # conv paired with the first residual, i.e. the pre-processed sample itself
        controlnet_down_blocks.append(zero_conv(output_channel))

        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    dtype=self.dtype,
                )
            else:
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)

            # one conv per layer of the block ...
            for _ in range(self.layers_per_block):
                controlnet_down_blocks.append(zero_conv(output_channel))

            # ... plus one more for every block that downsamples
            if not is_final_block:
                controlnet_down_blocks.append(zero_conv(output_channel))

        self.down_blocks = down_blocks
        self.controlnet_down_blocks = controlnet_down_blocks

        # mid
        mid_block_channel = block_out_channels[-1]
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=mid_block_channel,
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            dtype=self.dtype,
        )

        self.controlnet_mid_block = zero_conv(mid_block_channel)

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        controlnet_cond,
        conditioning_scale: float = 1.0,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxControlNetOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
            conditioning_scale: (`float`) the scale factor for controlnet outputs
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`FlaxControlNetOutput`] instead of a plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`FlaxControlNetOutput`] or `tuple`:
            [`FlaxControlNetOutput`] if `return_dict` is True, otherwise the tuple
            `(down_block_res_samples, mid_block_res_sample)`.
        """
        # only "bgr" triggers a flip; every other configured value leaves the input untouched
        channel_order = self.controlnet_conditioning_channel_order
        if channel_order == "bgr":
            controlnet_cond = jnp.flip(controlnet_cond, axis=1)

        # 1. time
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process: move channels last (NCHW -> NHWC) before the Flax convolutions
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        controlnet_cond = jnp.transpose(controlnet_cond, (0, 2, 3, 1))
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample += controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        # 5. controlnet blocks
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return FlaxControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
# Kept so that `from diffusers.models.cross_attention import AttnProcessor` keeps resolving.
AttnProcessor = AttentionProcessor


def _warn_deprecated_cross_class(instance):
    """Emit the deprecation warning shared by every legacy ``Cross*`` shim class in this module.

    The suggested replacement is the instance's class name with "Cross" removed
    (for example ``CrossAttnProcessor`` -> ``AttnProcessor``).
    """
    legacy_name = instance.__class__.__name__
    replacement_name = legacy_name.replace("Cross", "")
    # The original message opened a backtick before `from diffusers...` and never closed it;
    # the closing backtick after the import statement is the fix.
    deprecation_message = (
        f"{legacy_name} is deprecated and will be removed in `0.20.0`. Please use"
        f" `from diffusers.models.attention_processor import {replacement_name}` instead."
    )
    deprecate("cross_attention", "0.20.0", deprecation_message, standard_warn=False)


class CrossAttention(Attention):
    """Deprecated alias of `Attention`; warns on construction and otherwise behaves like its base class."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class CrossAttnProcessor(AttnProcessorRename):
    """Deprecated alias of `attention_processor.AttnProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class LoRACrossAttnProcessor(LoRAAttnProcessor):
    """Deprecated alias of `LoRAAttnProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class CrossAttnAddedKVProcessor(AttnAddedKVProcessor):
    """Deprecated alias of `AttnAddedKVProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class XFormersCrossAttnProcessor(XFormersAttnProcessor):
    """Deprecated alias of `XFormersAttnProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class LoRAXFormersCrossAttnProcessor(LoRAXFormersAttnProcessor):
    """Deprecated alias of `LoRAXFormersAttnProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnProcessor(SlicedAttnProcessor):
    """Deprecated alias of `SlicedAttnProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)


class SlicedCrossAttnAddedKVProcessor(SlicedAttnAddedKVProcessor):
    """Deprecated alias of `SlicedAttnAddedKVProcessor`; warns on construction."""

    def __init__(self, *args, **kwargs):
        _warn_deprecated_cross_class(self)
        super().__init__(*args, **kwargs)
class DualTransformer2DModel(nn.Module):
    """
    Wrapper around two identically configured `Transformer2DModel`s whose outputs are blended, for
    mixed inference with two separate conditions.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Passed unchanged to both `Transformer2DModel`s.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than `num_embeds_ada_norm` steps.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()

        # Both branches receive exactly the same configuration.
        branch_config = {
            "num_attention_heads": num_attention_heads,
            "attention_head_dim": attention_head_dim,
            "in_channels": in_channels,
            "num_layers": num_layers,
            "dropout": dropout,
            "norm_num_groups": norm_num_groups,
            "cross_attention_dim": cross_attention_dim,
            "attention_bias": attention_bias,
            "sample_size": sample_size,
            "num_vector_embeds": num_vector_embeds,
            "activation_fn": activation_fn,
            "num_embeds_ada_norm": num_embeds_ada_norm,
        }
        branches = []
        for _ in range(2):
            branches.append(Transformer2DModel(**branch_config))
        self.transformers = nn.ModuleList(branches)

        # Variables that can be set by a pipeline:

        # The ratio of transformer1 to transformer2's output states to be combined during inference
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        timestep=None,
        attention_mask=None,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        Run both transformers on the same input, each with its own slice of the condition tokens, and
        blend their residuals.

        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states:
                Conditional embeddings for the cross attention layers. They are sliced along dimension 1 into two
                consecutive segments of `condition_lengths[0]` and `condition_lengths[1]` tokens.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            attention_mask (`torch.FloatTensor`, *optional*):
                Accepted for interface compatibility; it is not used yet.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded unchanged to both transformers.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        residual_input = hidden_states

        branch_deltas = []
        offset = 0
        # attention_mask is not used yet
        for slot in range(2):
            # for each of the two transformers, pass the corresponding condition tokens
            segment_length = self.condition_lengths[slot]
            condition_tokens = encoder_hidden_states[:, offset : offset + segment_length]
            branch = self.transformers[self.transformer_index_for_condition[slot]]
            branch_output = branch(
                residual_input,
                encoder_hidden_states=condition_tokens,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]
            # keep only what this branch added on top of the shared input
            branch_deltas.append(branch_output - residual_input)
            offset += segment_length

        first_delta, second_delta = branch_deltas
        output_states = first_delta * self.mix_ratio + second_delta * (1 - self.mix_ratio)
        output_states = output_states + residual_input

        if return_dict:
            return Transformer2DModelOutput(sample=output_states)

        return (output_states,)
import math from typing import Optional import numpy as np import torch from torch import nn from .activations import get_activation def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1, scale: float = 1, max_period: int = 10000, ): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( start=0, end=half_dim, dtype=torch.float32, device=timesteps.device ) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = 
np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb class PatchEmbed(nn.Module): """2D Image to Patch Embedding""" def __init__( self, height=224, width=224, patch_size=16, in_channels=3, embed_dim=768, layer_norm=False, flatten=True, bias=True, ): super().__init__() num_patches = (height // patch_size) * (width // patch_size) self.flatten = flatten self.layer_norm = layer_norm self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias ) if layer_norm: self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) else: self.norm = None pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) def forward(self, latent): latent = self.proj(latent) if self.flatten: latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC if self.layer_norm: latent = self.norm(latent) return latent + 
self.pos_embed class TimestepEmbedding(nn.Module): def __init__( self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, post_act_fn: Optional[str] = None, cond_proj_dim=None, ): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) else: self.cond_proj = None self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) if post_act_fn is None: self.post_act = None else: self.post_act = get_activation(post_act_fn) def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) sample = self.linear_1(sample) if self.act is not None: sample = self.act(sample) sample = self.linear_2(sample) if self.post_act is not None: sample = self.post_act(sample) return sample class Timesteps(nn.Module): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos self.downscale_freq_shift = downscale_freq_shift def forward(self, timesteps): t_emb = get_timestep_embedding( timesteps, self.num_channels, flip_sin_to_cos=self.flip_sin_to_cos, downscale_freq_shift=self.downscale_freq_shift, ) return t_emb class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" def __init__( self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False ): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) self.log = log self.flip_sin_to_cos = flip_sin_to_cos if set_W_to_weight: # to delete later self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) self.weight = self.W def forward(self, x): if 
self.log: x = torch.log(x) x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi if self.flip_sin_to_cos: out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) else: out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return out class ImagePositionalEmbeddings(nn.Module): """ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the height and width of the latent space. For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 For VQ-diffusion: Output vector embeddings are used as input for the transformer. Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. Args: num_embed (`int`): Number of embeddings for the latent pixels embeddings. height (`int`): Height of the latent image i.e. the number of height embeddings. width (`int`): Width of the latent image i.e. the number of width embeddings. embed_dim (`int`): Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 
""" def __init__( self, num_embed: int, height: int, width: int, embed_dim: int, ): super().__init__() self.height = height self.width = width self.num_embed = num_embed self.embed_dim = embed_dim self.emb = nn.Embedding(self.num_embed, embed_dim) self.height_emb = nn.Embedding(self.height, embed_dim) self.width_emb = nn.Embedding(self.width, embed_dim) def forward(self, index): emb = self.emb(index) height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) # 1 x H x D -> 1 x H x 1 x D height_emb = height_emb.unsqueeze(2) width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) # 1 x W x D -> 1 x 1 x W x D width_emb = width_emb.unsqueeze(1) pos_emb = height_emb + width_emb # 1 x H x W x D -> 1 x L xD pos_emb = pos_emb.view(1, self.height * self.width, -1) emb = emb + pos_emb[:, : emb.shape[1], :] return emb class LabelEmbedding(nn.Module): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. Args: num_classes (`int`): The number of classes. hidden_size (`int`): The size of the vector embeddings. dropout_prob (`float`): The probability of dropping a label. """ def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = dropout_prob > 0 self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) self.num_classes = num_classes self.dropout_prob = dropout_prob def token_drop(self, labels, force_drop_ids=None): """ Drops labels to enable classifier-free guidance. 
""" if force_drop_ids is None: drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob else: drop_ids = torch.tensor(force_drop_ids == 1) labels = torch.where(drop_ids, self.num_classes, labels) return labels def forward(self, labels: torch.LongTensor, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings class TextImageProjection(nn.Module): def __init__( self, text_embed_dim: int = 1024, image_embed_dim: int = 768, cross_attention_dim: int = 768, num_image_text_embeds: int = 10, ): super().__init__() self.num_image_text_embeds = num_image_text_embeds self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): batch_size = text_embeds.shape[0] # image image_text_embeds = self.image_embeds(image_embeds) image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) # text text_embeds = self.text_proj(text_embeds) return torch.cat([image_text_embeds, text_embeds], dim=1) class ImageProjection(nn.Module): def __init__( self, image_embed_dim: int = 768, cross_attention_dim: int = 768, num_image_text_embeds: int = 32, ): super().__init__() self.num_image_text_embeds = num_image_text_embeds self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) self.norm = nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds: torch.FloatTensor): batch_size = image_embeds.shape[0] # image image_embeds = self.image_embeds(image_embeds) image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) image_embeds = self.norm(image_embeds) return image_embeds class CombinedTimestepLabelEmbeddings(nn.Module): def 
__init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) def forward(self, timestep, class_labels, hidden_dtype=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) class_labels = self.class_embedder(class_labels) # (N, D) conditioning = timesteps_emb + class_labels # (N, D) return conditioning class TextTimeEmbedding(nn.Module): def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): super().__init__() self.norm1 = nn.LayerNorm(encoder_dim) self.pool = AttentionPooling(num_heads, encoder_dim) self.proj = nn.Linear(encoder_dim, time_embed_dim) self.norm2 = nn.LayerNorm(time_embed_dim) def forward(self, hidden_states): hidden_states = self.norm1(hidden_states) hidden_states = self.pool(hidden_states) hidden_states = self.proj(hidden_states) hidden_states = self.norm2(hidden_states) return hidden_states class TextImageTimeEmbedding(nn.Module): def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): super().__init__() self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) self.text_norm = nn.LayerNorm(time_embed_dim) self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): # text time_text_embeds = self.text_proj(text_embeds) time_text_embeds = self.text_norm(time_text_embeds) # image time_image_embeds = self.image_proj(image_embeds) return time_image_embeds + time_text_embeds class ImageTimeEmbedding(nn.Module): def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): super().__init__() self.image_proj = 
nn.Linear(image_embed_dim, time_embed_dim) self.image_norm = nn.LayerNorm(time_embed_dim) def forward(self, image_embeds: torch.FloatTensor): # image time_image_embeds = self.image_proj(image_embeds) time_image_embeds = self.image_norm(time_image_embeds) return time_image_embeds class ImageHintTimeEmbedding(nn.Module): def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): super().__init__() self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) self.image_norm = nn.LayerNorm(time_embed_dim) self.input_hint_block = nn.Sequential( nn.Conv2d(3, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.SiLU(), nn.Conv2d(32, 32, 3, padding=1), nn.SiLU(), nn.Conv2d(32, 96, 3, padding=1, stride=2), nn.SiLU(), nn.Conv2d(96, 96, 3, padding=1), nn.SiLU(), nn.Conv2d(96, 256, 3, padding=1, stride=2), nn.SiLU(), nn.Conv2d(256, 4, 3, padding=1), ) def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor): # image time_image_embeds = self.image_proj(image_embeds) time_image_embeds = self.image_norm(time_image_embeds) hint = self.input_hint_block(hint) return time_image_embeds, hint class AttentionPooling(nn.Module): # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 def __init__(self, num_heads, embed_dim, dtype=None): super().__init__() self.dtype = dtype self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.num_heads = num_heads self.dim_per_head = embed_dim // self.num_heads def forward(self, x): bs, length, width = x.size() def shape(x): # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, -1, self.num_heads, self.dim_per_head) # (bs, length, n_heads, 
dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) x = x.transpose(1, 2) return x class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) # (bs*n_heads, class_token_length, dim_per_head) q = shape(self.q_proj(class_token)) # (bs*n_heads, length+class_token_length, dim_per_head) k = shape(self.k_proj(x)) v = shape(self.v_proj(x)) # (bs*n_heads, class_token_length, length+class_token_length): scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) # (bs*n_heads, dim_per_head, class_token_length) a = torch.einsum("bts,bcs->bct", weight, v) # (bs, length+1, width) a = a.reshape(bs, -1, 1).transpose(1, 2) return a[:, 0, :] # cls_token ================================================ FILE: 6DoF/diffusers/models/embeddings_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import math import flax.linen as nn import jax.numpy as jnp def get_sinusoidal_embeddings( timesteps: jnp.ndarray, embedding_dim: int, freq_shift: float = 1, min_timescale: float = 1, max_timescale: float = 1.0e4, flip_sin_to_cos: bool = False, scale: float = 1.0, ) -> jnp.ndarray: """Returns the positional encoding (same as Tensor2Tensor). Args: timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. embedding_dim: The number of output channels. min_timescale: The smallest time unit (should probably be 0.0). max_timescale: The largest time unit. Returns: a Tensor of timing signals [N, num_channels] """ assert timesteps.ndim == 1, "Timesteps should be a 1d-array" assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even" num_timescales = float(embedding_dim // 2) log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift) inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0) # scale embeddings scaled_time = scale * emb if flip_sin_to_cos: signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1) else: signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) return signal class FlaxTimestepEmbedding(nn.Module): r""" Time step Embedding Module. Learns embeddings for input time steps. 
Args: time_embed_dim (`int`, *optional*, defaults to `32`): Time step embedding dimension dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @nn.compact def __call__(self, temb): temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) temb = nn.silu(temb) temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) return temb class FlaxTimesteps(nn.Module): r""" Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 Args: dim (`int`, *optional*, defaults to `32`): Time step embedding dimension """ dim: int = 32 flip_sin_to_cos: bool = False freq_shift: float = 1 @nn.compact def __call__(self, timesteps): return get_sinusoidal_embeddings( timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift ) ================================================ FILE: 6DoF/diffusers/models/modeling_flax_pytorch_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
""" PyTorch - Flax general utilities.""" import re import jax.numpy as jnp from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey from ..utils import logging logger = logging.get_logger(__name__) def rename_key(key): regex = r"\w+[.]\d+" pats = re.findall(regex, key) for pat in pats: key = key.replace(pat, "_".join(pat.split("."))) return key ##################### # PyTorch => Flax # ##################### # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" # conv norm or layer norm renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) if ( any("norm" in str_ for str_ in pt_tuple_key) and (pt_tuple_key[-1] == "bias") and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) ): renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) return renamed_pt_tuple_key, pt_tensor elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) return renamed_pt_tuple_key, pt_tensor # embedding if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) return renamed_pt_tuple_key, pt_tensor # conv layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: pt_tensor = pt_tensor.transpose(2, 3, 1, 0) return renamed_pt_tuple_key, pt_tensor # linear layer renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) if pt_tuple_key[-1] == "weight": 
pt_tensor = pt_tensor.T return renamed_pt_tuple_key, pt_tensor # old PyTorch layer norm weight renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) if pt_tuple_key[-1] == "gamma": return renamed_pt_tuple_key, pt_tensor # old PyTorch layer norm bias renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) if pt_tuple_key[-1] == "beta": return renamed_pt_tuple_key, pt_tensor return pt_tuple_key, pt_tensor def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): # Step 1: Convert pytorch tensor to numpy pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} # Step 2: Since the model is stateless, get random Flax params random_flax_params = flax_model.init_weights(PRNGKey(init_key)) random_flax_state_dict = flatten_dict(random_flax_params) flax_state_dict = {} # Need to change some parameters name to match Flax names for pt_key, pt_tensor in pt_state_dict.items(): renamed_pt_key = rename_key(pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) # Correctly rename weight parameters flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) if flax_key in random_flax_state_dict: if flax_tensor.shape != random_flax_state_dict[flax_key].shape: raise ValueError( f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." ) # also add unexpected weight so that warning is thrown flax_state_dict[flax_key] = jnp.asarray(flax_tensor) return unflatten_dict(flax_state_dict) ================================================ FILE: 6DoF/diffusers/models/modeling_flax_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from pickle import UnpicklingError from typing import Any, Dict, Union import jax import jax.numpy as jnp import msgpack.exceptions from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError from .. import __version__, is_torch_available from ..utils import ( CONFIG_NAME, DIFFUSERS_CACHE, FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging, ) from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax logger = logging.get_logger(__name__) class FlaxModelMixin: r""" Base class for all Flax models. [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and saving models. - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _flax_internal_args = ["name", "parent", "dtype"] @classmethod def _from_config(cls, config, **kwargs): """ All context managers that the model should be initialized under go here. 
""" return cls(config, **kwargs) def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: """ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. """ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 def conditional_cast(param): if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): param = param.astype(dtype) return param if mask is None: return jax.tree_map(conditional_cast, params) flat_params = flatten_dict(params) flat_mask, _ = jax.tree_flatten(mask) for masked, key in zip(flat_mask, flat_params.keys()): if masked: param = flat_params[key] flat_params[key] = conditional_cast(param) return unflatten_dict(flat_params) def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast the `params` in place. This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` for params you want to cast, and `False` for those you want to skip. 
Examples: ```python >>> from diffusers import FlaxUNet2DConditionModel >>> # load model >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision >>> params = model.to_bf16(params) >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... } >>> mask = traverse_util.unflatten_dict(mask) >>> params = model.to_bf16(params, mask) ```""" return self._cast_floating_to(params, jnp.bfloat16, mask) def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` for params you want to cast, and `False` for those you want to skip. 
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
    `params` in place.

    This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full
    half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True`
            for params you want to cast, and `False` for those you want to skip.

    Returns:
        Whatever `self._cast_floating_to(params, jnp.float16, mask)` returns.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # load model
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # By default, the model params will be in fp32, to cast these to float16
    >>> params = model.to_fp16(params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> flat_params = traverse_util.flatten_dict(params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> params = model.to_fp16(params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.float16, mask)


def init_weights(self, rng: jax.random.KeyArray) -> Dict:
    """
    Hook that concrete models must override to build their parameter tree.

    `from_pretrained` evaluates this method under `jax.eval_shape` (with `rng=jax.random.PRNGKey(0)`) to learn
    which parameter keys and shapes the model expects.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError(f"init_weights method has to be implemented for {self}")


@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    dtype: jnp.dtype = jnp.float32,
    *model_args,
    **kwargs,
):
    r"""
    Instantiate a pretrained Flax model from a pretrained model configuration.

    Parameters:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:

                - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
                  hosted on the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  using [`~FlaxModelMixin.save_pretrained`].
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            <Tip>

            This only specifies the dtype of the *computation* and does not influence the dtype of model
            parameters.

            If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
            [`~FlaxModelMixin.to_bf16`].

            </Tip>

        model_args (sequence of positional arguments, *optional*):
            Accepted for interface compatibility. NOTE(review): this method does not forward them anywhere.
        config (*optional*):
            When given, it is handed to `cls.from_config` in place of `pretrained_model_name_or_path`.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
            incompletely downloaded files are deleted.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        use_auth_token (*optional*):
            Forwarded unchanged to `cls.from_config` and to `hf_hub_download`.
        revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git. Defaults to `None`, which is forwarded as-is to the config and download helpers.
        subfolder (`str`, *optional*):
            Sub-directory of the repository or local directory that holds the weights file.
        from_pt (`bool`, *optional*, defaults to `False`):
            Load the model weights from a PyTorch checkpoint save file.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Everything not consumed above is forwarded to [`~ConfigMixin.from_config`] together with `dtype`.
            NOTE(review): the unused kwargs that `from_config` hands back are not used any further here.

    Returns:
        A `(model, params)` tuple: the instantiated model and the nested parameter dictionary restricted to the
        keys the model expects.

    Raises:
        EnvironmentError: if the weights file cannot be located, downloaded or deserialized, or if `from_pt=True`
            and PyTorch is not installed.
        OSError: if the weights file looks like a git-lfs pointer.
        ValueError: if a checkpoint array has a different shape than the corresponding model parameter.

    Examples:

    ```python
    >>> from diffusers import FlaxUNet2DConditionModel

    >>> # Download model and configuration from huggingface.co and cache.
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
    >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
    ```

    Keys missing from the checkpoint are only reported through a warning; call `model.init_weights` to create
    them before using the returned parameters.
    """
    config = kwargs.pop("config", None)
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    force_download = kwargs.pop("force_download", False)
    from_pt = kwargs.pop("from_pt", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", False)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)

    user_agent = {
        "diffusers": __version__,
        "file_type": "model",
        "framework": "flax",
    }

    # The file we are after depends on the requested source format. Resolving it once keeps the download call
    # and the "file not found" message in agreement.
    weights_name = FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME

    # Load config if we don't provide a configuration
    config_path = config if config is not None else pretrained_model_name_or_path
    model, _unused_model_kwargs = cls.from_config(
        config_path,
        cache_dir=cache_dir,
        return_unused_kwargs=True,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        use_auth_token=use_auth_token,
        revision=revision,
        subfolder=subfolder,
        # model args
        dtype=dtype,
        **kwargs,
    )

    # Load model
    pretrained_path_with_subfolder = (
        pretrained_model_name_or_path
        if subfolder is None
        else os.path.join(pretrained_model_name_or_path, subfolder)
    )
    if os.path.isdir(pretrained_path_with_subfolder):
        if from_pt:
            if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                )
            model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
        elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
            # Load from a Flax checkpoint
            model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
        # Check if pytorch weights exist instead
        elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
            raise EnvironmentError(
                f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
                " using `from_pt=True`."
            )
        else:
            raise EnvironmentError(
                f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                f"{pretrained_path_with_subfolder}."
            )
    else:
        try:
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=weights_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision,
            )

        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                "login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                "this model name. Check the model page at "
                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            # Name the file that was actually requested (PyTorch weights when `from_pt=True`).
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
                f"{err}"
            )
        except ValueError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
                " internet connection or see how to run the library in offline mode at"
                " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
            )

    if from_pt:
        if is_torch_available():
            from .modeling_utils import load_state_dict
        else:
            raise EnvironmentError(
                "Can't load the model in PyTorch format because PyTorch is not installed. "
                "Please, install PyTorch or use native Flax weights."
            )

        # Step 1: Get the pytorch file
        pytorch_model_file = load_state_dict(model_file)

        # Step 2: Convert the weights
        state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
    else:
        try:
            with open(model_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            # A file that is readable as text and starts with "version" is treated as a git-lfs pointer;
            # anything else is reported as an undecodable checkpoint.
            try:
                with open(model_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")

    # make sure all arrays are stored as jnp.ndarray
    # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
    # https://github.com/google/flax/issues/1261
    state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

    # flatten dicts
    state = flatten_dict(state)

    # Trace `init_weights` abstractly to obtain the expected keys/shapes without allocating parameters.
    params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
    shape_state = flatten_dict(unfreeze(params_shape_tree))
    required_params = set(shape_state.keys())

    missing_keys = required_params - set(state.keys())
    unexpected_keys = set(state.keys()) - required_params

    if missing_keys:
        logger.warning(
            f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
            "Make sure to call model.init_weights to initialize the missing weights."
        )
        cls._missing_keys = missing_keys

    for key, value in state.items():
        if key in shape_state and value.shape != shape_state[key].shape:
            raise ValueError(
                f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                f"{value.shape} which is incompatible with the model shape {shape_state[key].shape}. "
            )

    # remove unexpected keys to not be saved again
    for unexpected_key in unexpected_keys:
        del state[unexpected_key]

    if len(unexpected_keys) > 0:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
            " with another architecture."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
            " training."
        )

    return model, unflatten_dict(state)


def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    params: Union[Dict, FrozenDict],
    is_main_process: bool = True,
):
    """
    Save a model and its configuration file to a directory so that it can be reloaded using the
    [`~FlaxModelMixin.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save a model and its configuration file to. Will be created if it doesn't exist.
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and
            you need to call this function on all processes. Only the configuration save is gated on this flag;
            the weights file is written by every caller.

    If `save_directory` points at an existing file, an error is logged and nothing is written.
    """
    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    model_to_save = self

    # Attach architecture to the config
    # Save the config
    if is_main_process:
        model_to_save.save_config(save_directory)

    # save model
    output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
    with open(output_model_file, "wb") as f:
        model_bytes = to_bytes(params)
        f.write(model_bytes)

    logger.info(f"Model weights saved in {output_model_file}")
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    """
    Read the serialized Flax checkpoint at `model_file` and copy its weights into `pt_model`.

    When deserialization fails with `UnpicklingError`, the file is re-opened as text: content starting with
    "version" is reported as a missing git-lfs checkout (`OSError`), anything else as an unconvertible file
    (`EnvironmentError`).

    Returns the value of `load_flax_weights_in_pytorch_model(pt_model, flax_state)`.
    """
    try:
        with open(model_file, "rb") as checkpoint_handle:
            serialized_state = checkpoint_handle.read()
            flax_state = from_bytes(None, serialized_state)
    except UnpicklingError as unpickling_error:
        try:
            with open(model_file) as text_handle:
                looks_like_lfs_pointer = text_handle.read().startswith("version")
            if not looks_like_lfs_pointer:
                raise ValueError from unpickling_error
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please"
                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                " folder you cloned."
            )
        except (UnicodeDecodeError, ValueError):
            raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")

    return load_flax_weights_in_pytorch_model(pt_model, flax_state)


def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """
    Copy the arrays of a Flax parameter tree into the state dict of `pt_model` and load it.

    Key translation performed per flattened Flax key:
      * a trailing `kernel` becomes `weight`; 4-D kernels are transposed with axes `(3, 2, 0, 1)`, all other
        kernels with `.T` (presumably conv and dense layouts respectively -- confirm against the Flax modules);
      * a trailing `scale` becomes `weight`;
      * unless one path component is exactly `time_embedding`, every `_<digit>` is rewritten to `.<digit>`.

    Side effects: sets `pt_model.base_model_prefix = ""` and calls `pt_model.load_state_dict`. Unused Flax keys
    and PyTorch keys that received no value are reported through `logger.warning`.

    Raises:
        ImportError: re-raised (after logging) when `torch` cannot be imported.
        ValueError: when a translated key exists in the PyTorch model but the shapes differ.

    Returns:
        `pt_model`.
    """
    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # bfloat16 leaves are up-cast first: torch.from_numpy cannot ingest them.
    bf16_flags = flatten_dict(jax.tree_util.tree_map(lambda leaf: leaf.dtype == jnp.bfloat16, flax_state))
    if any(bf16_flags.values()):
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )

        def _upcast_bf16(leaf):
            if leaf.dtype == jnp.bfloat16:
                return leaf.astype(np.float32)
            return leaf

        flax_state = jax.tree_util.tree_map(_upcast_bf16, flax_state)

    pt_model.base_model_prefix = ""

    flat_flax_state = flatten_dict(flax_state, sep=".")
    target_state = pt_model.state_dict()

    # Bookkeeping for the two warnings emitted at the end.
    leftover_flax_keys = []
    unfilled_pt_keys = set(target_state.keys())

    for dotted_flax_key, tensor in flat_flax_state.items():
        parts = dotted_flax_key.split(".")
        leaf_name = parts[-1]

        if leaf_name == "kernel":
            parts[-1] = "weight"
            tensor = jnp.transpose(tensor, (3, 2, 0, 1)) if tensor.ndim == 4 else tensor.T
        elif leaf_name == "scale":
            parts[-1] = "weight"

        if "time_embedding" not in parts:
            for position, part in enumerate(parts):
                # Same effect as chaining .replace("_0", ".0") ... .replace("_9", ".9").
                for digit in "0123456789":
                    part = part.replace("_" + digit, "." + digit)
                parts[position] = part

        pt_key = ".".join(parts)

        if pt_key not in target_state:
            # weight is not expected by PyTorch model
            leftover_flax_keys.append(pt_key)
            continue

        if tensor.shape != target_state[pt_key].shape:
            raise ValueError(
                f"Flax checkpoint seems to be incorrect. Weight {dotted_flax_key} was expected "
                f"to be of shape {target_state[pt_key].shape}, but is {tensor.shape}."
            )

        # add weight to pytorch dict
        if not isinstance(tensor, np.ndarray):
            tensor = np.asarray(tensor)
        target_state[pt_key] = torch.from_numpy(tensor)
        # remove from missing keys
        unfilled_pt_keys.remove(pt_key)

    pt_model.load_state_dict(target_state)

    # re-transform missing_keys to list
    still_missing = list(unfilled_pt_keys)

    if leftover_flax_keys:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {leftover_flax_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    if still_missing:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {still_missing}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    return pt_model
def _find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
    """Return the `(name, tensor)` pairs stored as plain attributes in `module.__dict__`."""
    return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]


def get_parameter_device(parameter: torch.nn.Module):
    """
    Return the device of the first parameter of `parameter`, or of its first buffer when it has no parameters.

    When the module has neither, the first plain tensor attribute found through `Module._named_members` is used
    instead (kept for `torch.nn.DataParallel` compatibility in PyTorch 1.5).

    Raises:
        StopIteration: if no parameter, buffer or tensor attribute exists.
    """
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        gen = parameter._named_members(get_members_fn=_find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module):
    """
    Return the dtype of the first parameter of `parameter`, or of its first buffer when it has no parameters.

    When the module has neither, the first plain tensor attribute found through `Module._named_members` is used
    (for `torch.nn.DataParallel` compatibility in PyTorch 1.5). This fallback used to sit in an
    `except StopIteration` handler that could never run, because materialising parameters and buffers into
    tuples does not raise.

    Returns:
        A `torch.dtype`, or `None` when the module holds no tensor at all.
    """
    params = tuple(parameter.parameters())
    if len(params) > 0:
        return params[0].dtype

    buffers = tuple(parameter.buffers())
    if len(buffers) > 0:
        return buffers[0].dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    gen = parameter._named_members(get_members_fn=_find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        return None
    return first_tuple[1].dtype


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    Files whose basename equals `_add_variant(WEIGHTS_NAME, variant)` are read with `torch.load` onto the CPU;
    every other file is read with `safetensors.torch.load_file`.

    Raises:
        OSError: if the file looks like a git-lfs pointer (text starting with "version"), or if it cannot be
            loaded for any other reason.
    """
    try:
        if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
            return torch.load(checkpoint_file, map_location="cpu")
        else:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                # Only the prefix matters; reading it alone avoids pulling a whole (possibly multi-GB)
                # checkpoint into memory as text just to produce an error message.
                if f.read(7).startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )


def _load_state_dict_into_model(model_to_load, state_dict):
    """
    Copy `state_dict` into `model_to_load` module by module and return the collected error messages.

    The caller's `state_dict` is shallow-copied first, so the mapping passed in is not modified.
    """
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix=""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs
@property
def is_gradient_checkpointing(self) -> bool:
    """
    Report whether any sub-module (as yielded by `self.modules()`) carries a truthy `gradient_checkpointing`
    attribute.
    """
    for submodule in self.modules():
        if getattr(submodule, "gradient_checkpointing", False):
            return True
    return False


def enable_gradient_checkpointing(self):
    """
    Turn gradient checkpointing on by applying `self._set_gradient_checkpointing(..., value=True)` to the
    module tree via `self.apply` (also known as *activation checkpointing* in other frameworks).

    Raises:
        ValueError: if `self._supports_gradient_checkpointing` is falsy.
    """
    if not self._supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    switch_on = partial(self._set_gradient_checkpointing, value=True)
    self.apply(switch_on)


def disable_gradient_checkpointing(self):
    """
    Turn gradient checkpointing off by applying `self._set_gradient_checkpointing(..., value=False)` via
    `self.apply`. Does nothing when the model does not support gradient checkpointing.
    """
    if not self._supports_gradient_checkpointing:
        return

    switch_off = partial(self._set_gradient_checkpointing, value=False)
    self.apply(switch_off)


def set_use_memory_efficient_attention_xformers(
    self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
    """
    Forward `(valid, attention_op)` to every descendant module that exposes a
    `set_use_memory_efficient_attention_xformers` method.

    The walk starts at the direct children of `self` and visits each module before its own children.
    """

    def _propagate(node: torch.nn.Module):
        # Notify the node itself first, then descend.
        if hasattr(node, "set_use_memory_efficient_attention_xformers"):
            node.set_use_memory_efficient_attention_xformers(valid, attention_op)

        for descendant in node.children():
            _propagate(descendant)

    for direct_child in self.children():
        if not isinstance(direct_child, torch.nn.Module):
            continue
        _propagate(direct_child)


def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
    r"""
    Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).

    When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
    inference. Speed up during training is not guaranteed.

    <Tip warning={true}>

    ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
    precedence.

    </Tip>

    Parameters:
        attention_op (`Callable`, *optional*):
            Override the default `None` operator for use as `op` argument to the
            [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
            function of xFormers.

    Examples:

    ```py
    >>> import torch
    >>> from diffusers import UNet2DConditionModel
    >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

    >>> model = UNet2DConditionModel.from_pretrained(
    ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
    ... )
    >>> model = model.to("cuda")
    >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
    ```
    """
    self.set_use_memory_efficient_attention_xformers(True, attention_op)


def disable_xformers_memory_efficient_attention(self):
    r"""
    Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
    """
    self.set_use_memory_efficient_attention_xformers(False)


def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    is_main_process: bool = True,
    save_function: Callable = None,
    safe_serialization: bool = False,
    variant: Optional[str] = None,
):
    """
    Save a model and its configuration file to a directory so that it can be reloaded using the
    [`~models.ModelMixin.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save a model and its configuration file to. Will be created if it doesn't exist.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and
            you need to call this function on all processes. Only the configuration save is gated on this flag;
            the weights file is written by every caller.
        save_function (`Callable`):
            Accepted for interface compatibility. NOTE(review): this method does not use it; weights are always
            written with `torch.save` or `safetensors.torch.save_file`.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        variant (`str`, *optional*):
            Passed to `_add_variant` to derive the weights filename (presumably
            `pytorch_model.<variant>.bin` -- confirm against `_add_variant`).

    Raises:
        ImportError: if `safe_serialization` is requested but `safetensors` is not available.

    If `save_directory` points at an existing file, an error is logged and nothing is written.
    """
    safetensors_missing = not is_safetensors_available()
    if safe_serialization and safetensors_missing:
        raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    model_to_save = self

    # Attach architecture to the config
    # Save the config
    if is_main_process:
        model_to_save.save_config(save_directory)

    # Save the model
    state_dict = model_to_save.state_dict()

    if safe_serialization:
        base_weights_name = SAFETENSORS_WEIGHTS_NAME
    else:
        base_weights_name = WEIGHTS_NAME
    weights_name = _add_variant(base_weights_name, variant)
    weights_path = os.path.join(save_directory, weights_name)

    # Save the model
    if safe_serialization:
        safetensors.torch.save_file(state_dict, weights_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, weights_path)

    logger.info(f"Model weights saved in {weights_path}")
If `"auto"` is passed, the dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. 
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if `device_map` contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. variant (`str`, *optional*): Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. 
If set to `False`, `safetensors` weights are not loaded. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Example: ```py from diffusers import UNet2DConditionModel unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") ``` If you get the error message below, you need to finetune the weights for your downstream task: ```bash Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. 
``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) force_download = kwargs.pop("force_download", False) from_flax = kwargs.pop("from_flax", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False logger.warning( "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" " install accelerate\n```\n." ) if device_map is not None and not is_accelerate_available(): raise NotImplementedError( "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" " `device_map=None`. You can install accelerate with `pip install accelerate`." 
) # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." ) if low_cpu_mem_usage is False and device_map is not None: raise ValueError( f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path user_agent = { "diffusers": __version__, "file_type": "model", "framework": "pytorch", } # load config config, unused_kwargs, commit_hash = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, return_commit_hash=True, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, device_map=device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) # load model model_file = None if from_flax: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=FLAX_WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) model = cls.from_config(config, **unused_kwargs) # Convert the weights from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model model = 
load_flax_checkpoint_in_pytorch_model(model, model_file) else: if use_safetensors: try: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) except IOError as e: if not allow_pickle: raise e pass if model_file is None: model_file = _get_model_file( pretrained_model_name_or_path, weights_name=_add_variant(WEIGHTS_NAME, variant), cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, ) if low_cpu_mem_usage: # Instantiate model with empty weights with accelerate.init_empty_weights(): model = cls.from_config(config, **unused_kwargs) # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if len(missing_keys) > 0: raise ValueError( f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" " those weights or else make sure your checkpoint file is correct." 
) unexpected_keys = [] empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): accepts_dtype = "dtype" in set( inspect.signature(set_module_tensor_to_device).parameters.keys() ) if param_name not in empty_state_dict: unexpected_keys.append(param_name) continue if empty_state_dict[param_name].shape != param.shape: raise ValueError( f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." ) if accepts_dtype: set_module_tensor_to_device( model, param_name, param_device, value=param, dtype=torch_dtype ) else: set_module_tensor_to_device(model, param_name, param_device, value=param) if cls._keys_to_ignore_on_load_unexpected is not None: for pat in cls._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] if len(unexpected_keys) > 0: logger.warn( f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" ) else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU try: accelerate.load_checkpoint_and_dispatch( model, model_file, device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, ) except AttributeError as e: # When using accelerate loading, we do not have the ability to load the state # dict and rename the weight names manually. 
Additionally, accelerate skips # torch loading conventions and directly writes into `module.{_buffers, _parameters}` # (which look like they should be private variables?), so we can't use the standard hooks # to rename parameters on load. We need to mimic the original weight names so the correct # attributes are available. After we have loaded the weights, we convert the deprecated # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert # the weights so we don't have to do this again. if "'Attention' object has no attribute" in str(e): logger.warn( f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" " was saved with deprecated attention block weight names. We will load it with the deprecated attention block" " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," " please also re-upload it or open a PR on the original repository." 
) model._temp_convert_self_to_deprecated_attention_blocks() accelerate.load_checkpoint_and_dispatch( model, model_file, device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, ) model._undo_temp_convert_self_to_deprecated_attention_blocks() else: raise e loading_info = { "missing_keys": [], "unexpected_keys": [], "mismatched_keys": [], "error_msgs": [], } else: model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, model_file, pretrained_model_name_or_path, ignore_mismatched_sizes=ignore_mismatched_sizes, ) loading_info = { "missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "mismatched_keys": mismatched_keys, "error_msgs": error_msgs, } if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 
@classmethod
def _load_pretrained_model(
    cls,
    model,
    state_dict,
    resolved_archive_file,
    pretrained_model_name_or_path,
    ignore_mismatched_sizes=False,
):
    """
    Load a checkpoint `state_dict` into `model` and report how the two key sets compare.

    Args:
        model: The instantiated module that receives the weights.
        state_dict: Mapping from parameter name to tensor read from the checkpoint. When
            `ignore_mismatched_sizes` is `True`, entries whose shape differs from the model's are
            deleted from this mapping in place before loading.
        resolved_archive_file: Not read by this method; kept for signature compatibility.
        pretrained_model_name_or_path: Identifier of the checkpoint, used only in log messages.
        ignore_mismatched_sizes (`bool`, defaults to `False`): Whether shape-mismatched checkpoint
            entries are dropped (and reported) instead of being handed to the loader.

    Returns:
        `tuple`: `(model, missing_keys, unexpected_keys, mismatched_keys, error_msgs)`, where
        `mismatched_keys` holds `(key, checkpoint_shape, model_shape)` triples.

    Raises:
        RuntimeError: If `_load_state_dict_into_model` returns any error message.
    """
    # Key bookkeeping is done up front, before any mismatched entry is removed from `state_dict`.
    reference_state = model.state_dict()
    checkpoint_keys = list(state_dict.keys())
    expected_keys = list(reference_state.keys())

    missing_keys = list(set(expected_keys) - set(checkpoint_keys))
    unexpected_keys = list(set(checkpoint_keys) - set(expected_keys))

    if state_dict is not None:
        # Whole checkpoint
        mismatched_keys = []
        if ignore_mismatched_sizes:
            # Iterate over the key snapshot so entries can be deleted from `state_dict` safely.
            for checkpoint_key in checkpoint_keys:
                if checkpoint_key not in reference_state:
                    continue
                checkpoint_shape = state_dict[checkpoint_key].shape
                model_shape = reference_state[checkpoint_key].shape
                if checkpoint_shape != model_shape:
                    mismatched_keys.append((checkpoint_key, checkpoint_shape, model_shape))
                    del state_dict[checkpoint_key]
        error_msgs = _load_state_dict_into_model(model, state_dict)

    if error_msgs:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

    # Checkpoint entries the model has no slot for.
    if unexpected_keys:
        logger.warning(
            f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
            " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
            " identical (initializing a BertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

    # Model parameters the checkpoint did not provide.
    if missing_keys:
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif not mismatched_keys:
        logger.info(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
            f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
            " without further training."
        )

    # Entries that were dropped because their shapes disagree with the model.
    if mismatched_keys:
        mismatched_warning = "\n".join(
            f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
            for key, shape1, shape2 in mismatched_keys
        )
        logger.warning(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
            " able to use it for predictions and inference."
        )

    return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
def _convert_deprecated_attention_blocks(self, state_dict):
    """
    Rename deprecated attention-block entries of `state_dict` in place.

    Every module reachable from `self` whose `_from_deprecated_attn_block` attribute is truthy has
    its `query` / `key` / `value` / `proj_attn` weight and bias entries moved to `to_q` / `to_k` /
    `to_v` / `to_out.0`. Entries that are absent are skipped, because the state dict may already
    have been converted. All other keys (for example the group norm) are left untouched.

    Args:
        state_dict: Mutable mapping from parameter name to value; modified in place.
    """
    # (deprecated sub-module name, current sub-module name), in the order they are rewritten.
    renames = (
        ("query", "to_q"),
        ("key", "to_k"),
        ("value", "to_v"),
        ("proj_attn", "to_out.0"),
    )

    # Pre-order walk over the module tree using an explicit stack; children are pushed in
    # reverse so they are popped in their registration order.
    flagged_paths = []
    pending = [("", self)]
    while pending:
        prefix, current = pending.pop()
        if getattr(current, "_from_deprecated_attn_block", False):
            flagged_paths.append(prefix)
        children = [
            (child_name if prefix == "" else f"{prefix}.{child_name}", child)
            for child_name, child in current.named_children()
        ]
        pending.extend(reversed(children))

    for path in flagged_paths:
        for old_name, new_name in renames:
            for suffix in ("weight", "bias"):
                old_key = f"{path}.{old_name}.{suffix}"
                if old_key in state_dict:
                    state_dict[f"{path}.{new_name}.{suffix}"] = state_dict.pop(old_key)
class PriorTransformer(ModelMixin, ConfigMixin):
    """
    A Prior Transformer model.

    The model projects the noisy image embedding, the timestep, a projected conditioning embedding and
    (optionally) per-token encoder hidden states to a common width, concatenates them into one token
    sequence, runs a stack of transformer blocks over it and projects the result to `clip_embed_dim`.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`.
        num_embeddings (`int`, *optional*, defaults to 77):
            The number of embeddings of the model input `hidden_states`.
        additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
            projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
            additional_embeddings`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
            The activation function to use to create timestep embeddings.
        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
            passing to Transformer blocks. Either `"layer"` or `None` if normalization is not needed.
        embedding_proj_norm_type (`str`, *optional*, defaults to None):
            The normalization layer to apply on the input `proj_embedding`. Either `"layer"` or `None` if
            normalization is not needed.
        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
            `encoder_hidden_states` is `None`.
        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
            Choose from `prd` or `None`. If `prd` is chosen, a token indicating the (quantized) dot product between
            the text embedding and image embedding, as proposed in the unCLIP paper
            https://arxiv.org/abs/2204.06125, is added to the sequence. If it is `None`, no additional embedding
            is added.
        time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
            If None, will be set to `num_attention_heads * attention_head_dim`.
        embedding_proj_dim (`int`, *optional*, defaults to None):
            The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
        clip_embed_dim (`int`, *optional*, defaults to None):
            The dimension of the output. If None, will be set to `embedding_dim`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        num_layers: int = 20,
        embedding_dim: int = 768,
        num_embeddings=77,
        additional_embeddings=4,
        dropout: float = 0.0,
        time_embed_act_fn: str = "silu",
        norm_in_type: Optional[str] = None,  # layer
        embedding_proj_norm_type: Optional[str] = None,  # layer
        encoder_hid_proj_type: Optional[str] = "linear",  # linear
        added_emb_type: Optional[str] = "prd",  # prd
        time_embed_dim: Optional[int] = None,
        embedding_proj_dim: Optional[int] = None,
        clip_embed_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        # Width shared by every token that enters the transformer blocks.
        inner_dim = num_attention_heads * attention_head_dim
        self.additional_embeddings = additional_embeddings

        # Unset (or zero) optional dimensions fall back to the defaults described in the class docstring.
        time_embed_dim = time_embed_dim or inner_dim
        embedding_proj_dim = embedding_proj_dim or embedding_dim
        clip_embed_dim = clip_embed_dim or embedding_dim

        self.time_proj = Timesteps(inner_dim, True, 0)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)

        self.proj_in = nn.Linear(embedding_dim, inner_dim)

        # Optional normalization applied to `proj_embedding` before it is projected.
        if embedding_proj_norm_type is None:
            self.embedding_proj_norm = None
        elif embedding_proj_norm_type == "layer":
            self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
        else:
            raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")

        self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)

        # Optional projection of the per-token encoder hidden states to `inner_dim`.
        if encoder_hid_proj_type is None:
            self.encoder_hidden_states_proj = None
        elif encoder_hid_proj_type == "linear":
            self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
        else:
            raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")

        # Learned positional embedding, initialized to zeros.
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))

        # Optional learned token appended at the end of the sequence (see `forward`).
        if added_emb_type == "prd":
            self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
        elif added_emb_type is None:
            self.prd_embedding = None
        else:
            raise ValueError(
                f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
            )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    activation_fn="gelu",
                    attention_bias=True,
                )
                for d in range(num_layers)
            ]
        )

        # Optional normalization applied right before the transformer blocks.
        if norm_in_type == "layer":
            self.norm_in = nn.LayerNorm(inner_dim)
        elif norm_in_type is None:
            self.norm_in = None
        else:
            raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")

        self.norm_out = nn.LayerNorm(inner_dim)

        self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)

        # Additive causal mask: -10000.0 strictly above the diagonal, 0 on and below it, with a
        # leading broadcast dimension. Registered as a non-persistent buffer, so it is not part of
        # the saved state dict.
        causal_attention_mask = torch.full(
            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
        )
        causal_attention_mask.triu_(1)
        causal_attention_mask = causal_attention_mask[None, ...]
        self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)

        # Statistics used by `post_process_latents`; both start at zero here.
        self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
        self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))

    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by the dotted module path followed by `.processor`.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Any module exposing `set_processor` is treated as an attention layer.
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors. Note that
                entries are popped from the passed dict as they are assigned.

        Raises:
            ValueError: If a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    # A single processor instance is shared by every attention layer.
                    module.set_processor(processor)
                else:
                    # `pop` removes the entry from the caller's dict.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())

    def forward(
        self,
        hidden_states,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`PriorTransformer`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                The currently predicted image embeddings.
            timestep (`torch.LongTensor`):
                Current denoising step. A Python number or a 0-dimensional tensor is also accepted and is
                broadcast over the batch.
            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
                Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                Text mask for the text embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
                tuple.

        Returns:
            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
                If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise
                a tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If the model has an `encoder_hidden_states_proj` layer but `encoder_hidden_states` is
                `None`.
        """
        batch_size = hidden_states.shape[0]

        # Normalize `timestep` to a tensor with at least one dimension on the input's device.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(hidden_states.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

        timesteps_projected = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might be fp16, so we need to cast here.
        timesteps_projected = timesteps_projected.to(dtype=self.dtype)
        time_embeddings = self.time_embedding(timesteps_projected)

        if self.embedding_proj_norm is not None:
            proj_embedding = self.embedding_proj_norm(proj_embedding)

        proj_embeddings = self.embedding_proj(proj_embedding)
        if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
            encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
        elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")

        hidden_states = self.proj_in(hidden_states)

        positional_embeddings = self.positional_embedding.to(hidden_states.dtype)

        # Tokens are concatenated along dim 1 in this order:
        # [encoder_hidden_states (optional), proj_embeddings, time_embeddings, hidden_states, prd (optional)].
        additional_embeds = []
        additional_embeddings_len = 0

        if encoder_hidden_states is not None:
            additional_embeds.append(encoder_hidden_states)
            additional_embeddings_len += encoder_hidden_states.shape[1]

        # 2-D inputs get a token axis of length 1 inserted at dim 1.
        if len(proj_embeddings.shape) == 2:
            proj_embeddings = proj_embeddings[:, None, :]

        if len(hidden_states.shape) == 2:
            hidden_states = hidden_states[:, None, :]

        additional_embeds = additional_embeds + [
            proj_embeddings,
            time_embeddings[:, None, :],
            hidden_states,
        ]

        if self.prd_embedding is not None:
            prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
            additional_embeds.append(prd_embedding)

        hidden_states = torch.cat(
            additional_embeds,
            dim=1,
        )

        # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
        # After this line the count covers everything placed before `hidden_states` in the sequence:
        # the encoder tokens (if any), the projected-embedding tokens and the single time token.
        additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
        if positional_embeddings.shape[1] < hidden_states.shape[1]:
            # Pad only the token axis: `additional_embeddings_len` zeros in front and, when present,
            # the prd token's length behind; the feature axis is left unpadded.
            positional_embeddings = F.pad(
                positional_embeddings,
                (
                    0,
                    0,
                    additional_embeddings_len,
                    self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
                ),
                value=0.0,
            )

        hidden_states = hidden_states + positional_embeddings

        if attention_mask is not None:
            # Turn the mask into an additive bias: 0 where the mask is 1, -10000.0 where it is 0.
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            # The trailing additional tokens get a zero bias.
            attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
            # Combine with the causal mask, then repeat each batch entry once per attention head along dim 0.
            attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)

        if self.norm_in is not None:
            hidden_states = self.norm_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask)

        hidden_states = self.norm_out(hidden_states)

        if self.prd_embedding is not None:
            # With a prd token, only the last position of the sequence is kept.
            hidden_states = hidden_states[:, -1]
        else:
            # Otherwise drop the conditioning tokens that precede the image-embedding tokens.
            hidden_states = hidden_states[:, additional_embeddings_len:]

        predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)

        if not return_dict:
            return (predicted_image_embedding,)

        return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)

    def post_process_latents(self, prior_latents):
        """
        Scale `prior_latents` by `self.clip_std` and shift them by `self.clip_mean`.

        Args:
            prior_latents: Tensor to rescale; it must broadcast against the `(1, clip_embed_dim)` statistics.

        Returns:
            The rescaled tensor `prior_latents * clip_std + clip_mean`.
        """
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents
class Upsample1D(nn.Module):
    """1D upsampling layer (factor 2) with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels expected on axis 1 of the input.
        use_conv (`bool`, default `False`):
            apply a kernel-3 `Conv1d` after nearest-neighbour interpolation.
        use_conv_transpose (`bool`, default `False`):
            upsample with a strided `ConvTranspose1d` instead of interpolation (takes precedence over `use_conv`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        # Exactly one layer (or none) is stored under `conv`.
        layer = None
        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        self.conv = layer

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels

        # The transposed convolution does the upsampling by itself.
        if self.use_conv_transpose:
            return self.conv(inputs)

        upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled


class Downsample1D(nn.Module):
    """1D downsampling layer (stride 2) with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels expected on axis 1 of the input.
        use_conv (`bool`, default `False`):
            use a strided kernel-3 `Conv1d`; otherwise average pooling is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`; must equal `channels` when pooling.
        padding (`int`, default `1`):
            padding for the convolution.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        step = 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=step, stride=step)
        else:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=step, padding=padding)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
out_channels (`int`, optional): number of output channels. Defaults to `channels`. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name conv = None if use_conv_transpose: conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": self.conv = conv else: self.Conv2d_0 = conv def forward(self, hidden_states, output_size=None): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: return self.conv(hidden_states) # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch # https://github.com/pytorch/pytorch/issues/86679 dtype = hidden_states.dtype if dtype == torch.bfloat16: hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. 
class Downsample2D(nn.Module):
    """2D downsampling layer (stride 2) with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels expected on axis 1 of the input.
        use_conv (`bool`, default `False`):
            use a strided 3x3 `Conv2d`; otherwise 2x2 average pooling is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`; must equal `channels` when pooling.
        padding (`int`, default `1`):
            padding for the convolution. With `0`, the input is zero-padded by one pixel on the right and bottom
            before the convolution.
        name (`str`, default `"conv"`):
            with `"conv"` the layer is additionally registered as `Conv2d_0`; it is always reachable as `conv`.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            layer = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=padding)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # `Conv2d_0` is registered first so the module order stays the same.
        if name == "conv":
            self.Conv2d_0 = layer
        self.conv = layer

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # Pad right/bottom only; the channel axis is untouched.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        return self.conv(hidden_states)
The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `hidden_states`. """ assert isinstance(factor, int) and factor >= 1 # Setup filter kernel. if kernel is None: kernel = [1] * factor # setup kernel kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: kernel = torch.outer(kernel, kernel) kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) if self.use_conv: convH = weight.shape[2] convW = weight.shape[3] inC = weight.shape[1] pad_value = (kernel.shape[0] - factor) - (convW - 1) stride = (factor, factor) # Determine data dimensions. output_shape = ( (hidden_states.shape[2] - 1) * factor + convH, (hidden_states.shape[3] - 1) * factor + convW, ) output_padding = ( output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, ) assert output_padding[0] >= 0 and output_padding[1] >= 0 num_groups = hidden_states.shape[1] // inC # Transpose weights. 
class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            # Only the weight and bias of this layer are used; the strided convolution itself is applied
            # functionally in `_downsample_2d`.
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
        """Fused `Conv2d()` followed by `downsample_2d()`.

        Padding is performed only once at the beginning, not between the operations.

        Args:
            hidden_states: Input tensor; axes 2 and 3 are treated as height and width.
            weight:
                Convolution weight, required when `self.use_conv` is set. It is passed unchanged to `F.conv2d`,
                and its last two axes are read as the kernel height and width.
            kernel:
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`,
                which corresponds to average pooling.
            factor: Integer downsampling factor (default: 2).
            gain: Scaling factor for signal magnitude (default: 1.0).

        Returns:
            output: The filtered tensor, spatially reduced by `factor`.
        """

        assert isinstance(factor, int) and factor >= 1
        if kernel is None:
            kernel = [1] * factor

        # setup kernel: separable 1D filters become their outer product, then the filter is normalized to sum 1
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        kernel /= torch.sum(kernel)

        kernel = kernel * gain

        # `kernel` is already a tensor here: move it with `.to()` instead of re-wrapping it in `torch.tensor()`,
        # which copies the data and emits a copy-construct UserWarning on every call.
        fir = kernel.to(device=hidden_states.device)

        if self.use_conv:
            _, _, _, convW = weight.shape
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            # Low-pass filter first (no decimation), then let the strided convolution do the downsampling.
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                fir,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                fir,
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states):
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            # The functional convolution above has no bias term, so add it here, broadcast over N, H and W.
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states
# Downsample/upsample layers used in the k-upscaler; FirDownsample2D/FirUpsample2D might be usable instead.
class KDownsample2D(nn.Module):
    """Per-channel 2x downsampling with a fixed 4x4 smoothing kernel.

    The kernel is the outer product of `[1/8, 3/8, 3/8, 1/8]` with itself and is kept as a non-persistent
    buffer. Inputs are padded with `pad_mode` before the strided convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = taps.shape[1] // 2 - 1
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs):
        padded = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        num_channels = padded.shape[1]
        kernel_h, kernel_w = self.kernel.shape[0], self.kernel.shape[1]

        # Dense weight that is zero everywhere except on the channel diagonal, so channels are not mixed.
        weight = padded.new_zeros([num_channels, num_channels, kernel_h, kernel_w])
        diagonal = torch.arange(num_channels, device=padded.device)
        weight[diagonal, diagonal] = self.kernel.to(weight)[None, :].expand(num_channels, -1, -1)

        return F.conv2d(padded, weight, stride=2)


class KUpsample2D(nn.Module):
    """Per-channel 2x upsampling with a fixed 4x4 smoothing kernel.

    The kernel is the outer product of `2 * [1/8, 3/8, 3/8, 1/8]` with itself and is kept as a non-persistent
    buffer. Inputs are padded with `pad_mode` before the transposed convolution.
    """

    def __init__(self, pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        taps = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        self.pad = taps.shape[1] // 2 - 1
        self.register_buffer("kernel", taps.T @ taps, persistent=False)

    def forward(self, inputs):
        border = (self.pad + 1) // 2
        padded = F.pad(inputs, (border,) * 4, self.pad_mode)
        num_channels = padded.shape[1]
        kernel_h, kernel_w = self.kernel.shape[0], self.kernel.shape[1]

        # Dense weight that is zero everywhere except on the channel diagonal, so channels are not mixed.
        weight = padded.new_zeros([num_channels, num_channels, kernel_h, kernel_w])
        diagonal = torch.arange(num_channels, device=padded.device)
        weight[diagonal, diagonal] = self.kernel.to(weight)[None, :].expand(num_channels, -1, -1)

        return F.conv_transpose2d(padded, weight, stride=2, padding=self.pad * 2 + 1)
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. groups_out (`int`, *optional*, default to None): The number of groups to use for the second normalization layer. if set to None, same as `groups`. eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or "ada_group" for a stronger conditioning with scale and shift. kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. use_in_shortcut (`bool`, *optional*, default to `True`): If `True`, add a 1x1 nn.conv2d layer for skip-connection. up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the `conv_shortcut` output. conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. If None, same as `out_channels`. 
""" def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, groups_out=None, pre_norm=True, eps=1e-6, non_linearity="swish", skip_time_act=False, time_embedding_norm="default", # default, scale_shift, ada_group, spatial kernel=None, output_scale_factor=1.0, use_in_shortcut=None, up=False, down=False, conv_shortcut_bias: bool = True, conv_2d_out_channels: Optional[int] = None, ): super().__init__() self.pre_norm = pre_norm self.pre_norm = True self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.up = up self.down = down self.output_scale_factor = output_scale_factor self.time_embedding_norm = time_embedding_norm self.skip_time_act = skip_time_act if groups_out is None: groups_out = groups if self.time_embedding_norm == "ada_group": self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) elif self.time_embedding_norm == "spatial": self.norm1 = SpatialNorm(in_channels, temb_channels) else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") else: self.time_emb_proj = None if self.time_embedding_norm == "ada_group": self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) elif self.time_embedding_norm == "spatial": self.norm2 = SpatialNorm(out_channels, temb_channels) 
else: self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) self.upsample = self.downsample = None if self.up: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") else: self.upsample = Upsample2D(in_channels, use_conv=False) elif self.down: if kernel == "fir": fir_kernel = (1, 3, 3, 1) self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) elif kernel == "sde_vp": self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) else: self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None if self.use_in_shortcut: self.conv_shortcut = torch.nn.Conv2d( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) def forward(self, input_tensor, temb): hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: # upsample_nearest_nhwc fails with large batch sizes. 
# unet_rl.py
def rearrange_dims(tensor):
    """Add or remove a singleton axis at position 2, depending on the rank of `tensor`.

    * rank 2 `(a, b)`       -> `(a, b, 1)`
    * rank 3 `(a, b, c)`    -> `(a, b, 1, c)`
    * rank 4 `(a, b, _, c)` -> `(a, b, c)` (index 0 of axis 2 is kept)

    Raises:
        ValueError: if the rank of `tensor` is not 2, 3 or 4.
    """
    # Report the rank, not `len(tensor)`: the latter is the size of the first axis and raises a TypeError for
    # 0-d tensors, so the old message was misleading or never reached.
    ndim = len(tensor.shape)
    if ndim == 2:
        return tensor[:, :, None]
    if ndim == 3:
        return tensor[:, :, None, :]
    if ndim == 4:
        return tensor[:, :, 0, :]
    raise ValueError(f"`len(tensor.shape)`: {ndim} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.mish = nn.Mish()

    def forward(self, inputs):
        intermediate_repr = self.conv1d(inputs)
        # The normalization is applied on a 4D view: a singleton axis is inserted before it and removed after it.
        intermediate_repr = rearrange_dims(intermediate_repr)
        intermediate_repr = self.group_norm(intermediate_repr)
        intermediate_repr = rearrange_dims(intermediate_repr)
        output = self.mish(intermediate_repr)
        return output


# unet_rl.py
class ResidualTemporalBlock1D(nn.Module):
    """Two `Conv1dBlock`s with an additive time embedding in between and a residual connection.

    The residual path is a 1x1 `Conv1d` when `inp_channels != out_channels`, otherwise the identity.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
        super().__init__()
        self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
        self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)

        self.time_emb_act = nn.Mish()
        self.time_emb = nn.Linear(embed_dim, out_channels)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
        )

    def forward(self, inputs, t):
        """
        Args:
            inputs : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]

        returns:
            out : [ batch_size x out_channels x horizon ]
        """
        t = self.time_emb_act(t)
        t = self.time_emb(t)
        # `rearrange_dims` appends a trailing axis to the 2D embedding so it broadcasts over the horizon.
        out = self.conv_in(inputs) + rearrange_dims(t)
        out = self.conv_out(out)
        return out + self.residual_conv(inputs)
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
    r"""Downsample a batch of 2D images with the given FIR filter.

    The filter is normalized to sum to one and then scaled by `gain`, so constant inputs are scaled by `gain`
    away from the borders. Pixels outside the image are treated as zero.

    Args:
        hidden_states: Input tensor of the shape `[N, C, H, W]`.
        kernel:
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor: Integer downsampling factor (default: 2).
        gain: Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output: Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1

    taps = [1] * factor if kernel is None else kernel
    fir = torch.tensor(taps, dtype=torch.float32)
    if fir.ndim == 1:
        # Separable filter: expand to its 2D outer product.
        fir = torch.outer(fir, fir)
    fir = fir / torch.sum(fir)
    fir = fir * gain

    extra = fir.shape[0] - factor
    padding = ((extra + 1) // 2, extra // 2)
    return upfirdn2d_native(hidden_states, fir.to(device=hidden_states.device), down=factor, pad=padding)


def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
    """Zero-stuff by `up`, pad or crop by `pad`, filter with `kernel`, then keep every `down`-th sample.

    Args:
        tensor: Input of the shape `[N, C, H, W]`.
        kernel: 2D filter of the shape `[kH, kW]`; it is flipped on both axes before `F.conv2d` is applied.
        up: Integer upsampling factor, used for both spatial axes.
        down: Integer downsampling factor, used for both spatial axes.
        pad:
            `(before, after)` amounts used for both spatial axes. Positive values zero-pad, negative values crop.

    Returns:
        Tensor of the shape `[N, C, out_h, out_w]`.
    """
    pad_lo, pad_hi = pad[0], pad[1]

    _, channel, in_h, in_w = tensor.shape
    # Fold batch and channels together so that every plane is filtered on its own.
    planes = tensor.reshape(-1, in_h, in_w, 1)
    _, in_h, in_w, minor = planes.shape
    kernel_h, kernel_w = kernel.shape

    up_h, up_w = in_h * up, in_w * up
    padded_h = up_h + pad_lo + pad_hi
    padded_w = up_w + pad_lo + pad_hi

    # Zero-stuffing: put `up - 1` zeros after every sample along height and width.
    out = planes.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up - 1, 0, 0, 0, up - 1])
    out = out.view(-1, up_h, up_w, minor)

    # Positive pad amounts grow the plane, negative ones are applied afterwards as a crop.
    grow_lo, grow_hi = max(pad_lo, 0), max(pad_hi, 0)
    crop_lo, crop_hi = max(-pad_lo, 0), max(-pad_hi, 0)
    out = F.pad(out, [0, 0, grow_lo, grow_hi, grow_lo, grow_hi])
    out = out.to(planes.device)  # Move back to mps if necessary
    out = out[:, crop_lo : out.shape[1] - crop_hi, crop_lo : out.shape[2] - crop_hi, :]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape([-1, 1, padded_h, padded_w])
    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, flipped)
    out = out.reshape(-1, minor, padded_h - kernel_h + 1, padded_w - kernel_w + 1)
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down, ::down, :]

    out_h = (padded_h - kernel_h) // down + 1
    out_w = (padded_w - kernel_w) // down + 1

    return out.view(-1, channel, out_h, out_w)
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 """ def __init__(self, in_dim, out_dim=None, dropout=0.0): super().__init__() out_dim = out_dim or in_dim self.in_dim = in_dim self.out_dim = out_dim # conv layers self.conv1 = nn.Sequential( nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)) ) self.conv2 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv3 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) self.conv4 = nn.Sequential( nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), ) # zero out the last layer params,so the conv block is identity nn.init.zeros_(self.conv4[-1].weight) nn.init.zeros_(self.conv4[-1].bias) def forward(self, hidden_states, num_frames=1): hidden_states = ( hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4) ) identity = hidden_states hidden_states = self.conv1(hidden_states) hidden_states = self.conv2(hidden_states) hidden_states = self.conv3(hidden_states) hidden_states = self.conv4(hidden_states) hidden_states = identity + hidden_states hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape( (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:] ) return hidden_states ================================================ FILE: 6DoF/diffusers/models/resnet_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class FlaxUpsample2D(nn.Module):
    """Doubles height and width with nearest-neighbour resizing, then applies a 3x3 convolution."""

    # Number of output features of the convolution.
    out_channels: int
    # dtype passed to the convolution.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # The input shape is unpacked as (batch, height, width, channels).
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxDownsample2D(nn.Module):
    """Applies a 3x3 convolution with stride 2 and a padding of 1 on both spatial axes."""

    # Number of output features of the convolution.
    out_channels: int
    # dtype passed to the convolution.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),  # padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Disabled alternative (goes with the `padding="VALID"` note above): pad one pixel after height and width.
        # pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        # hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxResnetBlock2D(nn.Module):
    """Residual block: GroupNorm -> swish -> conv, plus a projected time embedding, then
    GroupNorm -> swish -> dropout -> conv, added to the (optionally 1x1-projected) input."""

    # Number of input channels; used to decide whether a shortcut projection is needed.
    in_channels: int
    # Number of output features; falls back to `in_channels` when None.
    out_channels: int = None
    # Rate passed to `nn.Dropout`.
    dropout_prob: float = 0.0
    # Force (True) or suppress (False) the 1x1 shortcut convolution; None enables it only when
    # `in_channels` differs from the resolved output channel count.
    use_nin_shortcut: bool = None
    # dtype passed to the convolutions and the dense projection.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # Projects the time embedding to `out_channels` features so it can be added to the feature map.
        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        # `deterministic` is forwarded to the dropout layer.
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.swish(temb))
        # Two singleton axes are inserted at position 1 so the embedding broadcasts over the feature map
        # (assumes temb is (batch, features) and hidden_states is channels-last — TODO confirm).
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        # Project the skip path only when a shortcut convolution was created in `setup`.
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual
class T5FilmDecoder(ModelMixin, ConfigMixin):
    """Stack of FiLM-conditioned T5 decoder layers mapping continuous inputs back to `input_dims` features.

    Args:
        input_dims: Feature size of the decoder input and of the returned tensor.
        targets_length: Number of rows in the (frozen) position-embedding table.
        max_decoder_noise_time: Factor the noise time is multiplied by before it is embedded;
            also used as `max_period` of the timestep embedding.
        d_model: Hidden size of the decoder.
        num_layers: Number of `DecoderLayer` modules.
        num_heads: Attention heads per layer.
        d_kv: Per-head dimension.
        d_ff: Inner size of the gated feed-forward block.
        dropout_rate: Dropout probability used throughout.
    """

    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        cond_dim = d_model * 4
        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, cond_dim, bias=False),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim, bias=False),
            nn.SiLU(),
        )

        # The position table is excluded from gradient updates.
        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad_(False)

        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)

        # FiLM conditional T5 decoder layers.
        self.decoders = nn.ModuleList(
            DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
            for _ in range(num_layers)
        )

        self.decoder_norm = T5LayerNorm(d_model)

        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input, key_input):
        """Return the outer product of a query mask and a key mask with an extra axis at position -3."""
        pairwise = query_input.unsqueeze(-1) * key_input.unsqueeze(-2)
        return pairwise.unsqueeze(-3)

    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        """Run the decoder.

        Args:
            encodings_and_masks: Iterable of `(encoding, mask)` pairs. Encodings are concatenated
                along dim 1 and used as cross-attention context.
            decoder_input_tokens: 3-D tensor `(batch, seq_length, input_dims)`.
            decoder_noise_time: Tensor of shape `(batch,)`.

        Returns:
            Tensor produced by the final `spec_out` projection.
        """
        batch, seq_length, _ = decoder_input_tokens.shape
        assert decoder_noise_time.shape == (batch,)

        device = decoder_input_tokens.device

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        max_time = self.config.max_decoder_noise_time
        time_steps = get_timestep_embedding(
            decoder_noise_time * max_time,
            embedding_dim=self.config.d_model,
            max_period=max_time,
        ).to(dtype=self.dtype)

        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)

        # If we want to use relative positions for audio context, we can just offset
        # this sequence by the length of encodings_and_masks.
        decoder_positions = torch.arange(seq_length, device=device).expand(batch, seq_length)
        position_encodings = self.position_encoding(decoder_positions)

        inputs = self.continuous_inputs_projection(decoder_input_tokens)
        inputs += position_encodings
        y = self.dropout(inputs)

        # decoder: No padding present.
        decoder_mask = torch.ones((batch, seq_length), device=device, dtype=inputs.dtype)

        # Translate encoding masks to encoder-decoder masks while collecting the encodings.
        encoding_parts = []
        mask_parts = []
        for encoding, encoding_mask in encodings_and_masks:
            encoding_parts.append(encoding)
            mask_parts.append(self.encoder_decoder_mask(decoder_mask, encoding_mask))

        # cross attend style: concat encodings
        encoded = torch.cat(encoding_parts, dim=1)
        encoder_decoder_mask = torch.cat(mask_parts, dim=-1)

        for decoder_layer in self.decoders:
            layer_outputs = decoder_layer(
                y,
                conditioning_emb=conditioning_emb,
                encoder_hidden_states=encoded,
                encoder_attention_mask=encoder_decoder_mask,
            )
            y = layer_outputs[0]

        y = self.post_dropout(self.decoder_norm(y))
        return self.spec_out(y)


class DecoderLayer(nn.Module):
    """One decoder layer: conditioned self-attention, cross-attention, FiLM-conditioned feed-forward."""

    def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
        super().__init__()
        self.layer = nn.ModuleList(
            [
                # cond self attention: layer 0
                T5LayerSelfAttentionCond(
                    d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate
                ),
                # cross attention: layer 1
                T5LayerCrossAttention(
                    d_model=d_model,
                    d_kv=d_kv,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    layer_norm_epsilon=layer_norm_epsilon,
                ),
                # Film Cond MLP + dropout: last layer
                T5LayerFFCond(
                    d_model=d_model,
                    d_ff=d_ff,
                    dropout_rate=dropout_rate,
                    layer_norm_epsilon=layer_norm_epsilon,
                ),
            ]
        )

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
    ):
        """Apply the three sub-layers and return a 1-tuple holding the new hidden states.

        Cross-attention is skipped when `encoder_hidden_states` is `None`.
        `encoder_decoder_position_bias` is accepted but not used.
        """
        states = self.layer[0](
            hidden_states,
            conditioning_emb=conditioning_emb,
            attention_mask=attention_mask,
        )

        if encoder_hidden_states is not None:
            # Positive mask entries become a bias of 0, everything else -1e10.
            additive_mask = torch.where(encoder_attention_mask > 0, 0, -1e10)
            additive_mask = additive_mask.to(encoder_hidden_states.dtype)

            states = self.layer[1](
                states,
                key_value_states=encoder_hidden_states,
                attention_mask=additive_mask,
            )

        # Apply Film Conditional Feed Forward layer
        states = self.layer[-1](states, conditioning_emb)

        return (states,)
class T5LayerSelfAttentionCond(nn.Module):
    """Pre-norm self-attention sub-layer with optional FiLM conditioning of the normed input.

    Args:
        d_model: Hidden size.
        d_kv: Per-head dimension.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention output.
    """

    def __init__(self, d_model, d_kv, num_heads, dropout_rate):
        super().__init__()
        self.layer_norm = T5LayerNorm(d_model)
        self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
    ):
        """Return `hidden_states` plus the (dropped-out) self-attention output.

        NOTE(review): `attention_mask` is accepted for interface symmetry but is not
        forwarded to the attention module; self-attention always runs unmasked.
        """
        # pre_self_attention_layer_norm
        normed_hidden_states = self.layer_norm(hidden_states)

        if conditioning_emb is not None:
            normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)

        # Self-attention block
        attention_output = self.attention(normed_hidden_states)

        hidden_states = hidden_states + self.dropout(attention_output)

        return hidden_states


class T5LayerCrossAttention(nn.Module):
    """Pre-norm cross-attention sub-layer with a residual connection.

    Args:
        d_model: Hidden size.
        d_kv: Per-head dimension.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention output.
        layer_norm_epsilon: Epsilon of the T5 layer norm.
    """

    def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        attention_mask=None,
    ):
        """Return `hidden_states` plus the (dropped-out) cross-attention output.

        Args:
            hidden_states: Query-side input.
            key_value_states: Passed to the attention module as `encoder_hidden_states`.
            attention_mask: Optional mask whose axis 1 is squeezed before it is handed to
                the attention module. `None` (the default) means no mask.
        """
        normed_hidden_states = self.layer_norm(hidden_states)

        # The signature advertises `attention_mask=None`; squeezing unconditionally
        # would raise AttributeError for that default, so only squeeze real masks.
        if attention_mask is not None:
            attention_mask = attention_mask.squeeze(1)

        attention_output = self.attention(
            normed_hidden_states,
            encoder_hidden_states=key_value_states,
            attention_mask=attention_mask,
        )
        layer_output = hidden_states + self.dropout(attention_output)
        return layer_output
class T5LayerFFCond(nn.Module):
    """Pre-norm gated feed-forward sub-layer with optional FiLM conditioning and a residual connection."""

    def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, hidden_states, conditioning_emb=None):
        """Return `hidden_states` plus the dropped-out feed-forward branch.

        FiLM modulation of the normed input is skipped when `conditioning_emb` is `None`.
        """
        branch = self.layer_norm(hidden_states)
        if conditioning_emb is not None:
            branch = self.film(branch, conditioning_emb)

        branch = self.DenseReluDense(branch)
        return hidden_states + self.dropout(branch)


class T5DenseGatedActDense(nn.Module):
    """Gated feed-forward block: `wo(dropout(gelu(wi_0(x)) * wi_1(x)))`, all projections bias-free."""

    def __init__(self, d_model, d_ff, dropout_rate):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states):
        gate = self.act(self.wi_0(hidden_states))
        value = self.wi_1(hidden_states)
        gated = self.dropout(gate * value)
        return self.wo(gated)


class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Scale `hidden_states` by the reciprocal root of its mean square over the last axis."""
        # T5 layer norm only rescales (RMS norm, https://arxiv.org/abs/1910.07467):
        # there is no mean subtraction and no bias. The mean square is
        # accumulated in fp32 so half-precision inputs do not lose accuracy.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)

        # Cast back down when the learned scale itself is stored in half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)

        return self.weight * normed


class NewGELUActivation(nn.Module):
    """
    Tanh approximation of GELU, as used in the Google BERT repo (identical to OpenAI GPT). Also see the
    Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        cubic = input + 0.044715 * torch.pow(input, 3.0)
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * cubic)
        return 0.5 * input * gate


class T5FiLMLayer(nn.Module):
    """
    FiLM layer: one bias-free projection of the conditioning yields a scale and a shift, applied as
    `x * (1 + scale) + shift`.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

    def forward(self, x, conditioning_emb):
        # First half of the projection is the scale, second half the shift.
        scale, shift = self.scale_bias(conditioning_emb).chunk(2, dim=-1)
        return x * (1 + scale) + shift
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    A 2D Transformer model for image-like data.

    The constructor selects one of three input modes from its arguments:

    * continuous: `in_channels` is set and `patch_size` is `None`;
    * vectorized (discrete): `num_vector_embeds` is set;
    * patches: both `in_channels` and `patch_size` are set.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        out_channels (`int`, *optional*):
            Stored as `self.out_channels` (falls back to `in_channels`). Only the patch output head reads it; the
            continuous output projection still maps back to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups of the input `GroupNorm` (continuous mode only).
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlocks` attention should contain a bias parameter.
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            This is fixed during training since it is used to learn a number of position embeddings. Also required
            in patch mode, where it is used as both height and width of the patch embedding.
        num_vector_embeds (`int`, *optional*):
            The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*):
            Side length of a patch; together with `in_channels` it selects the patch input mode.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*):
            The number of diffusion steps used during training. Pass if at least one of the norm_layers is
            `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
            added to the hidden states.

            During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            In continuous mode, use `nn.Linear` instead of a 1x1 `nn.Conv2d` for the input/output projections.
        only_cross_attention (`bool`, *optional*, defaults to `False`): Forwarded to every `BasicTransformerBlock`.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to every `BasicTransformerBlock`.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            Forwarded to every `BasicTransformerBlock`. Overridden to `"ada_norm"` (with a deprecation warning) when
            it is `"layer_norm"` while `num_embeds_ada_norm` is set.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Forwarded to every `BasicTransformerBlock`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

        # Older configs left `norm_type` at its default while still providing
        # `num_embeds_ada_norm`; treat those as "ada_norm" and warn.
        if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
            deprecation_message = (
                f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
                " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
                " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
                " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
                " would be very nice if you could open a Pull request for the `transformer/config.json` file"
            )
            deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
            norm_type = "ada_norm"

        # Exactly one input mode must be selected.
        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif self.is_input_vectorized and self.is_input_patches:
            raise ValueError(
                f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
                " sure that either `num_vector_embeds` or `num_patches` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
            raise ValueError(
                f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
                f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = nn.Linear(in_channels, inner_dim)
            else:
                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )
        elif self.is_input_patches:
            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"

            self.height = sample_size
            self.width = sample_size

            self.patch_size = patch_size
            self.pos_embed = PatchEmbed(
                height=sample_size,
                width=sample_size,
                patch_size=patch_size,
                in_channels=in_channels,
                embed_dim=inner_dim,
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
                self.proj_out = nn.Linear(inner_dim, in_channels)
            else:
                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            # One logit per class except the last one, hence `- 1`.
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
        elif self.is_input_patches:
            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
            # proj_out_1 yields the (shift, scale) pair, proj_out_2 the per-patch pixels.
            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        posemb: Optional[Any] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            posemb (*optional*):
                Passed unchanged to every transformer block as `posemb`. NOTE(review): presumably the camera-pose
                embedding consumed by the attention layers of this fork — confirm against `BasicTransformerBlock`.
            cross_attention_kwargs (`dict`, *optional*):
                Passed unchanged to every transformer block.
            attention_mask ( `torch.Tensor`, *optional*):
                Self-attention mask. A 2-D tensor is interpreted as a `(batch, key_tokens)` mask with 1 = keep,
                0 = discard and converted to an additive bias of shape `(batch, 1, key_tokens)`; any other rank is
                passed through untouched.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #       (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        if self.is_input_continuous:
            batch, _, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                # Conv projection runs on the image layout; `inner_dim` is the
                # channel count AFTER projection.
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # Linear projection runs on the flattened token layout; here
                # `inner_dim` is the channel count BEFORE projection, which is
                # what the output reshape below needs after `proj_out`.
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
        elif self.is_input_patches:
            hidden_states = self.pos_embed(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                posemb=posemb,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                hidden_states = self.proj_out(hidden_states)
            else:
                hidden_states = self.proj_out(hidden_states)
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0)); the softmax is evaluated in float64 and cast back to float32
            output = F.log_softmax(logits.double(), dim=1).float()
        elif self.is_input_patches:
            # TODO: cleanup!
            # Reuses the conditioning embedder of the first block's `norm1`.
            conditioning = self.transformer_blocks[0].norm1.emb(
                timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
            shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
            hidden_states = self.proj_out_2(hidden_states)

            # unpatchify: the token count is treated as a square grid of patches
            height = width = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.reshape(
                shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
            )
            hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
            output = hidden_states.reshape(
                shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
            )

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    The output of [`TransformerTemporalModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor


class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data. Attention runs along the frame axis: every spatial location of every
    batch item becomes one sequence of `num_frames` tokens.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output. It sizes the group norm and both linear projections.
        out_channels (`int`, *optional*):
            Accepted for configuration purposes; the constructor body does not read it.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Number of groups of the input `GroupNorm`.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*):
            Accepted for configuration purposes; the constructor body does not read it.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Forwarded to every `BasicTransformerBlock`.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for d in range(num_layers)
            ]
        )

        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
        return_dict: bool = True,
    ):
        """
        The [`TransformerTemporal`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size * num_frames, channel, height, width)`):
                Input hidden_states. The leading axis must be divisible by `num_frames`.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            num_frames (`int`, *optional*, defaults to 1):
                Number of frames folded into the leading axis of `hidden_states`.
            cross_attention_kwargs (`dict`, *optional*):
                Passed unchanged to every transformer block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of
                a plain tuple.

        Returns:
            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
                returned, otherwise a `tuple` where the first element is the sample tensor.
        """
        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        residual = hidden_states

        # (batch * frames, C, H, W) -> (batch, frames, C, H, W) -> (batch, C, frames, H, W)
        # so that GroupNorm sees the channel axis at position 1.
        hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        hidden_states = self.norm(hidden_states)
        # -> (batch, H, W, frames, C) -> (batch * H * W, frames, C): one token sequence per pixel.
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        # NOTE: the data is laid out as (batch, H, W, frames, C) at this point although
        # the reshape labels the last two axes (C, frames). The permute keeps those
        # two axes adjacent and in the same order, and the final reshape re-splits
        # their flattened index as (frames, C), so the result is the correct
        # (batch * frames, C, H, W) tensor.
        hidden_states = (
            hidden_states[None, None, :]
            .reshape(batch_size, height, width, channel, num_frames)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)
Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. extra_in_channels (`int`, *optional*, defaults to 0): Number of additional channels to be added to the input of the first down block. Useful for cases where the input data has more channels than what the model was initially designed for. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sin to cos for Fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): Tuple of block output channels. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. downsample_each_block (`int`, *optional*, defaults to `False`): Experimental feature for using a UNet without upsampling. 
""" @register_to_config def __init__( self, sample_size: int = 65536, sample_rate: Optional[int] = None, in_channels: int = 2, out_channels: int = 2, extra_in_channels: int = 0, time_embedding_type: str = "fourier", flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, freq_shift: float = 0.0, down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), mid_block_type: Tuple[str] = "UNetMidBlock1D", out_block_type: str = None, block_out_channels: Tuple[int] = (32, 32, 64), act_fn: str = None, norm_num_groups: int = 8, layers_per_block: int = 1, downsample_each_block: bool = False, ): super().__init__() self.sample_size = sample_size # time if time_embedding_type == "fourier": self.time_proj = GaussianFourierProjection( embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": self.time_proj = Timesteps( block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift ) timestep_input_dim = block_out_channels[0] if use_timestep_embedding: time_embed_dim = block_out_channels[0] * 4 self.time_mlp = TimestepEmbedding( in_channels=timestep_input_dim, time_embed_dim=time_embed_dim, act_fn=act_fn, out_dim=block_out_channels[0], ) self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) self.out_block = None # down output_channel = in_channels for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] if i == 0: input_channel += extra_in_channels is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], add_downsample=not is_final_block or 
downsample_each_block, ) self.down_blocks.append(down_block) # mid self.mid_block = get_mid_block( mid_block_type, in_channels=block_out_channels[-1], mid_channels=block_out_channels[-1], out_channels=block_out_channels[-1], embed_dim=block_out_channels[0], num_layers=layers_per_block, add_downsample=downsample_each_block, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] if out_block_type is None: final_upsample_channels = out_channels else: final_upsample_channels = block_out_channels[0] for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel output_channel = ( reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels ) is_final_block = i == len(block_out_channels) - 1 up_block = get_up_block( up_block_type, num_layers=layers_per_block, in_channels=prev_output_channel, out_channels=output_channel, temb_channels=block_out_channels[0], add_upsample=not is_final_block, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.out_block = get_out_block( out_block_type=out_block_type, num_groups_out=num_groups_out, embed_dim=block_out_channels[0], out_channels=out_channels, act_fn=act_fn, fc_dim=block_out_channels[-1] // 4, ) def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], return_dict: bool = True, ) -> Union[UNet1DOutput, Tuple]: r""" The [`UNet1DModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. 
Returns: [`~models.unet_1d.UNet1DOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # 1. time timesteps = timestep if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) timestep_embed = self.time_proj(timesteps) if self.config.use_timestep_embedding: timestep_embed = self.time_mlp(timestep_embed) else: timestep_embed = timestep_embed[..., None] timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:])) # 2. down down_block_res_samples = () for downsample_block in self.down_blocks: sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed) down_block_res_samples += res_samples # 3. mid if self.mid_block: sample = self.mid_block(sample, timestep_embed) # 4. up for i, upsample_block in enumerate(self.up_blocks): res_samples = down_block_res_samples[-1:] down_block_res_samples = down_block_res_samples[:-1] sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) # 5. post-process if self.out_block: sample = self.out_block(sample, timestep_embed) if not return_dict: return (sample,) return UNet1DOutput(sample=sample) ================================================ FILE: 6DoF/diffusers/models/unet_1d_blocks.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class DownResnetBlock1D(nn.Module):
    """
    Down block made of `num_layers + 1` `ResidualTemporalBlock1D` layers, an optional activation and an
    optional `Downsample1D`.

    `conv_shortcut`, `time_embedding_norm` and `output_scale_factor` are only stored on the instance;
    `groups` and `groups_out` are accepted but not used by this block.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        num_layers=1,
        conv_shortcut=False,
        temb_channels=32,
        groups=32,
        groups_out=None,
        non_linearity=None,
        time_embedding_norm="default",
        output_scale_factor=1.0,
        add_downsample=True,
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.add_downsample = add_downsample
        self.output_scale_factor = output_scale_factor

        # The first layer maps in -> out, followed by `num_layers` layers mapping out -> out.
        layer_inputs = [in_channels] + [out_channels] * num_layers
        self.resnets = nn.ModuleList(
            [ResidualTemporalBlock1D(width, out_channels, embed_dim=temb_channels) for width in layer_inputs]
        )

        self.nonlinearity = None if non_linearity is None else get_activation(non_linearity)
        self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) if add_downsample else None

    def forward(self, hidden_states, temb=None):
        """Return `(hidden_states, skips)`; `skips` holds the resnet output taken before activation/downsampling."""
        for layer in self.resnets:
            hidden_states = layer(hidden_states, temb)

        skips = (hidden_states,)

        if self.nonlinearity is not None:
            hidden_states = self.nonlinearity(hidden_states)
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states, skips
class MidResTemporalBlock1D(nn.Module):
    """
    Mid block made of `num_layers + 1` `ResidualTemporalBlock1D` layers followed by an optional
    upsampling or downsampling layer.

    Args:
        in_channels: Channels of the incoming hidden states.
        out_channels: Channels produced by every resnet layer (and kept by the resampling layer).
        embed_dim: Size of the time embedding forwarded to each `ResidualTemporalBlock1D`.
        num_layers: Number of additional `out_channels -> out_channels` resnet layers after the first one.
        add_downsample: Append a `Downsample1D(out_channels, use_conv=True)`.
        add_upsample: Append an `Upsample1D(out_channels, use_conv=True)`.
        non_linearity: Optional activation name resolved with `get_activation`; stored as
            `self.nonlinearity` (it is not applied in `forward`).

    Raises:
        ValueError: If both `add_downsample` and `add_upsample` are requested.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        embed_dim,
        num_layers: int = 1,
        add_downsample: bool = False,
        add_upsample: bool = False,
        non_linearity=None,
    ):
        super().__init__()
        # Reject the contradictory configuration before allocating any layers.
        if add_upsample and add_downsample:
            raise ValueError("Block cannot downsample and upsample")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_downsample = add_downsample

        # there will always be at least one resnet
        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]

        for _ in range(num_layers):
            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))

        self.resnets = nn.ModuleList(resnets)

        if non_linearity is None:
            self.nonlinearity = None
        else:
            self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        if add_upsample:
            # Fixed: this used to build a Downsample1D, so "upsampling" halved the length.
            self.upsample = Upsample1D(out_channels, use_conv=True)

        self.downsample = None
        if add_downsample:
            self.downsample = Downsample1D(out_channels, use_conv=True)

    def forward(self, hidden_states, temb):
        """Run all resnet layers, then the optional resampling layer, and return the hidden states."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)
        if self.downsample is not None:
            # Fixed: the result used to be assigned to `self.downsample`, which raises
            # TypeError on a registered submodule and never downsampled the states.
            hidden_states = self.downsample(hidden_states)

        return hidden_states
class OutValueFunctionBlock(nn.Module):
    """
    Output head that flattens the hidden states, appends the time embedding and maps the result to a
    single value per batch element through `Linear -> Mish -> Linear`.
    """

    def __init__(self, fc_dim, embed_dim):
        super().__init__()
        hidden_dim = fc_dim // 2
        head_layers = [
            nn.Linear(fc_dim + embed_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        ]
        self.final_block = nn.ModuleList(head_layers)

    def forward(self, hidden_states, temb):
        """Return a `(batch, 1)` tensor computed from the flattened `hidden_states` concatenated with `temb`."""
        batch_size = hidden_states.shape[0]
        features = torch.cat((hidden_states.view(batch_size, -1), temb), dim=-1)

        for layer in self.final_block:
            features = layer(features)

        return features
class Upsample1d(nn.Module):
    """
    Fixed-kernel 1D upsampler: pads the input and applies a stride-2 transposed convolution whose
    weight places the same (doubled) filter from `_kernels` on every channel's diagonal entry.
    """

    def __init__(self, kernel="linear", pad_mode="reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        # Filter taps are scaled by 2 before being stored as a (non-trainable) buffer.
        taps = torch.tensor(_kernels[kernel]) * 2
        self.pad = taps.shape[0] // 2 - 1
        self.register_buffer("kernel", taps)

    def forward(self, hidden_states, temb=None):
        """Upsample `hidden_states` of shape `(batch, channels, length)`; `temb` is ignored."""
        num_channels = hidden_states.shape[1]
        edge = (self.pad + 1) // 2
        padded = F.pad(hidden_states, (edge, edge), self.pad_mode)

        # Dense (channels, channels, taps) weight that is zero except on the channel diagonal,
        # so each channel is filtered independently.
        weight = padded.new_zeros([num_channels, num_channels, self.kernel.shape[0]])
        diagonal = torch.arange(num_channels, device=padded.device)
        weight[diagonal, diagonal] = self.kernel.to(weight)[None, :].expand(num_channels, -1)

        return F.conv_transpose1d(padded, weight, stride=2, padding=self.pad * 2 + 1)
class ResConvBlock(nn.Module):
    """
    Residual block of two width-5 1D convolutions with single-group GroupNorm and GELU in between.

    A bias-free 1x1 convolution is used on the skip path when `in_channels != out_channels`. When
    `is_last` is set, the second normalization and activation are not created or applied.
    """

    def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
        super().__init__()
        self.is_last = is_last
        self.has_conv_skip = in_channels != out_channels

        # Submodules are created in this exact order so parameter names and init order stay stable.
        if self.has_conv_skip:
            self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
        self.gelu_1 = nn.GELU()
        self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)

        if is_last:
            return
        self.group_norm_2 = nn.GroupNorm(1, out_channels)
        self.gelu_2 = nn.GELU()

    def forward(self, hidden_states):
        """Return the main branch output plus the (optionally projected) input."""
        if self.has_conv_skip:
            shortcut = self.conv_skip(hidden_states)
        else:
            shortcut = hidden_states

        branch = self.gelu_1(self.group_norm_1(self.conv_1(hidden_states)))
        branch = self.conv_2(branch)
        if not self.is_last:
            branch = self.gelu_2(self.group_norm_2(branch))

        return branch + shortcut
class AttnDownBlock1D(nn.Module):
    """
    Down block: cubic `Downsample1d`, then three (`ResConvBlock`, `SelfAttention1d`) pairs.

    `mid_channels` defaults to `out_channels`. Each attention layer uses `channels // 32` heads.
    """

    def __init__(self, out_channels, in_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = out_channels

        self.down = Downsample1d("cubic")

        # (input, output) widths of the three stages; the hidden width is always `mid_channels`.
        stage_io = [
            (in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        conv_stages = [ResConvBlock(c_in, mid_channels, c_out) for c_in, c_out in stage_io]
        attn_stages = [SelfAttention1d(c_out, c_out // 32) for _, c_out in stage_io]

        # Registration order (attentions before resnets) is kept as in the original layout.
        self.attentions = nn.ModuleList(attn_stages)
        self.resnets = nn.ModuleList(conv_stages)

    def forward(self, hidden_states, temb=None):
        """Return `(hidden_states, (hidden_states,))`; `temb` is ignored."""
        hidden_states = self.down(hidden_states)

        for conv_stage, attn_stage in zip(self.resnets, self.attentions):
            hidden_states = attn_stage(conv_stage(hidden_states))

        return hidden_states, (hidden_states,)
class UpBlock1D(nn.Module):
    """
    Up block: concatenates the last skip tensor on the channel axis, applies three `ResConvBlock`
    layers and finishes with a cubic `Upsample1d`.

    `mid_channels` defaults to `in_channels`. The first resnet expects `2 * in_channels` inputs
    because of the concatenated skip connection.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if mid_channels is None:
            mid_channels = in_channels

        stage_io = [
            (2 * in_channels, mid_channels),
            (mid_channels, mid_channels),
            (mid_channels, out_channels),
        ]
        self.resnets = nn.ModuleList([ResConvBlock(c_in, mid_channels, c_out) for c_in, c_out in stage_io])
        self.up = Upsample1d(kernel="cubic")

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        """Fuse the last entry of `res_hidden_states_tuple`, run the resnets and upsample; `temb` is ignored."""
        skip = res_hidden_states_tuple[-1]
        hidden_states = torch.cat([hidden_states, skip], dim=1)

        for stage in self.resnets:
            hidden_states = stage(hidden_states)

        return self.up(hidden_states)
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
    """
    Build the 1D up block named by `up_block_type`.

    `num_layers`, `temb_channels` and `add_upsample` are only forwarded to `UpResnetBlock1D`; the
    other block types receive just `in_channels` and `out_channels`.

    Raises:
        ValueError: If `up_block_type` is not one of the supported names.
    """
    channel_only_types = ("UpBlock1D", "AttnUpBlock1D", "UpBlock1DNoSkip")
    if up_block_type != "UpResnetBlock1D" and up_block_type not in channel_only_types:
        raise ValueError(f"{up_block_type} does not exist.")

    if up_block_type == "UpResnetBlock1D":
        return UpResnetBlock1D(
            in_channels=in_channels,
            num_layers=num_layers,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
        )

    if up_block_type == "UpBlock1D":
        block_cls = UpBlock1D
    elif up_block_type == "AttnUpBlock1D":
        block_cls = AttnUpBlock1D
    else:
        block_cls = UpBlock1DNoSkip
    return block_cls(in_channels=in_channels, out_channels=out_channels)
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
    """
    Build the optional output block named by `out_block_type` (keyword-only arguments).

    Returns an `OutConv1DBlock` for `"OutConv1DBlock"`, an `OutValueFunctionBlock` for
    `"ValueFunction"`, and `None` for any other value (no error is raised).
    """
    if out_block_type not in ("OutConv1DBlock", "ValueFunction"):
        return None

    if out_block_type == "ValueFunction":
        return OutValueFunctionBlock(fc_dim, embed_dim)

    return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - 1)`. in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip sin to cos for Fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block types. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): Tuple of block output channels. layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. downsample_type (`str`, *optional*, defaults to `conv`): The downsample type for downsampling layers. 
Choose between "conv" and "resnet" upsample_type (`str`, *optional*, defaults to `conv`): The upsample type for upsampling layers. Choose between "conv" and "resnet" act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class conditioning with `class_embed_type` equal to `None`. 
""" @register_to_config def __init__( self, sample_size: Optional[Union[int, Tuple[int, int]]] = None, in_channels: int = 3, out_channels: int = 3, center_input_sample: bool = False, time_embedding_type: str = "positional", freq_shift: int = 0, flip_sin_to_cos: bool = True, down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), block_out_channels: Tuple[int] = (224, 448, 672, 896), layers_per_block: int = 2, mid_block_scale_factor: float = 1, downsample_padding: int = 1, downsample_type: str = "conv", upsample_type: str = "conv", act_fn: str = "silu", attention_head_dim: Optional[int] = 8, norm_num_groups: int = 32, norm_eps: float = 1e-5, resnet_time_scale_shift: str = "default", add_attention: bool = True, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, ): super().__init__() self.sample_size = sample_size time_embed_dim = block_out_channels[0] * 4 # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
) # input self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) # time if time_embedding_type == "fourier": self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) else: self.class_embedding = None self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, ) self.down_blocks.append(down_block) # mid self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, 
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
    r"""
    The [`UNet2DModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

    Returns:
        [`~models.unet_2d.UNet2DOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
            returned where the first element is the sample tensor.

    Raises:
        ValueError: If the model has a class embedding and `class_labels` is `None`.
    """
    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # A Python float must stay floating point: forcing `torch.long` would silently truncate
        # fractional timesteps (and turn values below 1 into 0, which the fourier branch divides by).
        scalar_dtype = torch.float32 if isinstance(timesteps, float) else torch.long
        timesteps = torch.tensor([timesteps], dtype=scalar_dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when doing class conditioning")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)
            # same reasoning as for `t_emb`: the projection output is cast to the model dtype
            # before it reaches the (possibly half-precision) embedding layers
            class_labels = class_labels.to(dtype=self.dtype)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # 2. pre-process
    skip_sample = sample
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        # blocks exposing `skip_conv` additionally consume and return the running skip sample
        if hasattr(downsample_block, "skip_conv"):
            sample, res_samples, skip_sample = downsample_block(
                hidden_states=sample, temb=emb, skip_sample=skip_sample
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    sample = self.mid_block(sample, emb)

    # 5. up
    skip_sample = None
    for upsample_block in self.up_blocks:
        # each up block pops as many residuals off the end of the stack as it has resnets
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        if hasattr(upsample_block, "skip_conv"):
            sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
        else:
            sample = upsample_block(sample, res_samples, emb)

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if skip_sample is not None:
        sample += skip_sample

    if self.config.time_embedding_type == "fourier":
        # reshape to (batch, 1, 1, ...) so the division broadcasts over the non-batch dimensions
        timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
        sample = sample / timesteps

    if not return_dict:
        return (sample,)

    return UNet2DOutput(sample=sample)
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    downsample_type=None,
):
    """
    Build the down-sampling block named by `down_block_type`.

    A leading `"UNetRes"` prefix on `down_block_type` is stripped before the lookup. Each block class receives
    only the subset of arguments listed in its branch below.

    Returns:
        The constructed block instance.

    Raises:
        ValueError: If `down_block_type` matches no known block, or if a cross-attention block
            (`CrossAttnDownBlock2D`, `SimpleCrossAttnDownBlock2D`) is requested without `cross_attention_dim`.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif down_block_type == "AttnDownBlock2D":
        # this block has no `add_downsample` flag: "no downsampling" is expressed as `downsample_type=None`
        if add_downsample is False:
            downsample_type = None
        else:
            downsample_type = downsample_type or "conv"  # default to 'conv'
        return AttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            downsample_type=downsample_type,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "KDownBlock2D":
        return KDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif down_block_type == "KCrossAttnDownBlock2D":
        return KCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # self-attention is enabled exactly when the block does not downsample
            add_self_attention=not add_downsample,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    transformer_layers_per_block=1,
    num_attention_heads=None,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    attention_head_dim=None,
    upsample_type=None,
):
    """
    Build the up-sampling block named by `up_block_type`.

    A leading `"UNetRes"` prefix on `up_block_type` is stripped before the lookup. Each block class receives
    only the subset of arguments listed in its branch below.

    Returns:
        The constructed block instance.

    Raises:
        ValueError: If `up_block_type` matches no known block, or if a cross-attention block
            (`CrossAttnUpBlock2D`, `SimpleCrossAttnUpBlock2D`) is requested without `cross_attention_dim`.
    """
    # If attn head dim is not defined, we default it to the number of heads
    if attention_head_dim is None:
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling
        logger.warning(
            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
        )
        attention_head_dim = num_attention_heads

    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "ResnetUpsampleBlock2D":
        return ResnetUpsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "SimpleCrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
        return SimpleCrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif up_block_type == "AttnUpBlock2D":
        # this block has no `add_upsample` flag: "no upsampling" is expressed as `upsample_type=None`
        if add_upsample is False:
            upsample_type = None
        else:
            upsample_type = upsample_type or "conv"  # default to 'conv'

        return AttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            upsample_type=upsample_type,
        )
    elif up_block_type == "SkipUpBlock2D":
        return SkipUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "AttnSkipUpBlock2D":
        return AttnSkipUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "UpDecoderBlock2D":
        return UpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    elif up_block_type == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            attention_head_dim=attention_head_dim,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )
    elif up_block_type == "KUpBlock2D":
        return KUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif up_block_type == "KCrossAttnUpBlock2D":
        return KCrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
        )

    raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock2DCrossAttn(nn.Module):
    """
    Mid block made of one leading ResNet block followed by `num_layers` (transformer, ResNet) pairs.

    Each transformer is a `Transformer2DModel`, or a `DualTransformer2DModel` when `dual_cross_attention`
    is set. Channel count is kept at `in_channels` throughout.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def build_resnet():
            # every resnet in this block shares one configuration
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # there is always at least one resnet
        resnet_layers = [build_resnet()]
        attention_layers = []

        for _ in range(num_layers):
            if dual_cross_attention:
                transformer = DualTransformer2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                transformer = Transformer2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            attention_layers.append(transformer)
            resnet_layers.append(build_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        posemb: Optional = None,
    ) -> torch.FloatTensor:
        """
        Run the leading resnet, then alternate transformer and resnet for every remaining layer.

        `posemb` is forwarded untouched to each transformer; the first element of the transformer's
        tuple output is used as the new hidden states.
        """
        hidden_states = self.resnets[0](hidden_states, temb)

        # attention i is followed by resnet i + 1
        for layer_idx, transformer in enumerate(self.attentions):
            transformer_out = transformer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
                posemb=posemb,
            )
            hidden_states = transformer_out[0]
            hidden_states = self.resnets[layer_idx + 1](hidden_states, temb)

        return hidden_states
time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
class AttnDownBlock2D(nn.Module):
    """
    Down block of `num_layers` (ResNet, attention) pairs followed by an optional downsampler.

    `downsample_type` selects the downsampler: `"conv"` uses `Downsample2D`, `"resnet"` uses a
    `ResnetBlock2D` built with `down=True`, and any other value (e.g. `None`) disables downsampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        downsample_padding=1,
        downsample_type="conv",
    ):
        super().__init__()
        resnets = []
        attentions = []
        self.downsample_type = downsample_type

        if attention_head_dim is None:
            # The fallback really is `out_channels`; the message now names it correctly.
            # `Logger.warn` is a deprecated alias, hence `warning`.
            logger.warning(
                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # only the first resnet changes the channel count
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if downsample_type == "conv":
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        elif downsample_type == "resnet":
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        down=True,
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, upsample_size=None):
        """
        Apply every (resnet, attention) pair, then the downsampler if one was configured.

        `upsample_size` is accepted but not used by this block.

        Returns:
            A tuple `(hidden_states, output_states)` where `output_states` holds the hidden states after
            each attention layer plus, when downsampling is enabled, the downsampled hidden states.
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states)
            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                # the resnet-style downsampler also takes the time embedding
                if self.downsample_type == "resnet":
                    hidden_states = downsampler(hidden_states, temb=temb)
                else:
                    hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
__init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def 
class DownBlock2D(nn.Module):
    """Chain of `ResnetBlock2D` layers followed by an optional `Downsample2D`.

    `forward` returns the final hidden states together with a tuple holding
    the output of every resnet and, when a downsampler exists, the
    downsampled result as the last entry.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Only the first resnet consumes `in_channels`; later ones map out -> out.
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Run the resnets (checkpointed while training if enabled), then downsample.

        Returns:
            `(hidden_states, output_states)` where `output_states` is a tuple of
            the intermediate outputs described in the class docstring.
        """

        def wrap(module):
            # Plain-function wrapper bound to one specific module, handed to
            # torch.utils.checkpoint.
            def run(*inputs):
                return module(*inputs)

            return run

        collected = []

        for resnet in self.resnets:
            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
            else:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                extra = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(wrap(resnet), hidden_states, temb, **extra)

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class AttnDownEncoderBlock2D(nn.Module):
    """Encoder block of (`ResnetBlock2D`, `Attention`) pairs plus an optional `Downsample2D`.

    The resnets are built with `temb_channels=None` and always called with
    `temb=None`. The attention layers are called on the hidden states alone.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()
        resnets = []
        attentions = []

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias of `Logger.warning`. The message
            # now names `out_channels`, which is the value actually used below.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            # Only the first resnet consumes `in_channels`; later ones map out -> out.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        """Apply each resnet/attention pair in order, then the downsamplers if present."""
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb=None)
            hidden_states = attn(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states
class SkipDownBlock2D(nn.Module):
    """Resnet stack whose downsampling step also mixes in a separate `skip_sample` input.

    When built with `add_downsample=True`, `forward` runs a down-sampling
    resnet on the features, passes `skip_sample` through the FIR downsampler,
    projects it with a 1x1 convolution and adds it to the features.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        output_scale_factor=np.sqrt(2.0),
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        def make_resnet(block_in):
            # Group counts follow the channel widths, capped at 32.
            return ResnetBlock2D(
                in_channels=block_in,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=min(block_in // 4, 32),
                groups_out=min(out_channels // 4, 32),
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # Only the first resnet consumes `in_channels`; later ones map out -> out.
        self.resnets = nn.ModuleList(
            [make_resnet(in_channels if layer_idx == 0 else out_channels) for layer_idx in range(num_layers)]
        )

        if not add_downsample:
            self.resnet_down = None
            self.downsamplers = None
            self.skip_conv = None
            return

        self.resnet_down = ResnetBlock2D(
            in_channels=out_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=min(out_channels // 4, 32),
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_in_shortcut=True,
            down=True,
            kernel="fir",
        )
        self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
        # The skip branch is expected to carry 3 channels (hard-coded here).
        self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, hidden_states, temb=None, skip_sample=None):
        """Return `(hidden_states, output_states, skip_sample)`.

        `output_states` holds every resnet output and, if downsampling is
        enabled, the merged features as its last element. `skip_sample` is
        returned after passing through the downsamplers (unchanged otherwise).
        """
        collected = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            hidden_states = self.resnet_down(hidden_states, temb)
            for downsampler in self.downsamplers:
                skip_sample = downsampler(skip_sample)

            hidden_states = self.skip_conv(skip_sample) + hidden_states
            collected.append(hidden_states)

        return hidden_states, tuple(collected), skip_sample
class ResnetDownsampleBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, skip_time_act=False, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, down=True, ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for 
class SimpleCrossAttnDownBlock2D(nn.Module):
    """Down block of (`ResnetBlock2D`, added-KV `Attention`) pairs with a resnet downsampler.

    Each attention layer uses `AttnAddedKVProcessor2_0` when
    `F.scaled_dot_product_attention` is available and `AttnAddedKVProcessor`
    otherwise. The optional downsampler is a `ResnetBlock2D` built with
    `down=True`, so it also receives `temb`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_downsample=True,
        skip_time_act=False,
        only_cross_attention=False,
        cross_attention_norm=None,
    ):
        super().__init__()

        self.has_cross_attention = True

        resnets = []
        attentions = []

        self.attention_head_dim = attention_head_dim
        self.num_heads = out_channels // self.attention_head_dim

        for i in range(num_layers):
            # Only the first resnet consumes `in_channels`; later ones map out -> out.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    skip_time_act=skip_time_act,
                )
            )

            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            attentions.append(
                Attention(
                    query_dim=out_channels,
                    cross_attention_dim=out_channels,
                    heads=self.num_heads,
                    dim_head=attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        skip_time_act=skip_time_act,
                        down=True,
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run every resnet/attention pair, then the downsamplers.

        Args:
            hidden_states: Input features.
            temb: Embedding forwarded to every resnet and to the downsamplers.
            encoder_hidden_states: Conditioning passed to the attention layers.
            attention_mask: If given, used as the attention mask unconditionally.
            cross_attention_kwargs: Extra keyword arguments expanded into each
                attention call; `None` is treated as an empty dict.
            encoder_attention_mask: Used as the mask only when `attention_mask`
                is `None` and `encoder_hidden_states` is provided.

        Returns:
            `(hidden_states, output_states)`; `output_states` contains one
            entry per resnet/attention pair plus the downsampled result when a
            downsampler exists.
        """
        output_states = ()
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        if attention_mask is None:
            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # Pass `use_reentrant` explicitly where supported, as the other
                # checkpointed blocks in this file do.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            # The attention layer is called the same way on both paths. The
            # previous checkpointed call passed `cross_attention_kwargs` as a
            # positional dict, injected `return_dict=False`, and indexed the
            # result with `[0]`, none of which matches this call.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **cross_attention_kwargs,
            )

            output_states = output_states + (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, temb)

            output_states = output_states + (hidden_states,)

        return hidden_states, output_states
class KCrossAttnDownBlock2D(nn.Module):
    """Down block of (`ResnetBlock2D`, `KAttentionBlock`) pairs with an optional `KDownsample2D`.

    The resnets use `time_embedding_norm="ada_group"` and derive their group
    counts from `resnet_group_size`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        cross_attention_dim: int,
        dropout: float = 0.0,
        num_layers: int = 4,
        resnet_group_size: int = 32,
        add_downsample=True,
        attention_head_dim: int = 64,
        add_self_attention: bool = False,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True

        for i in range(num_layers):
            # Only the first resnet consumes `in_channels`; later ones map out -> out.
            in_channels = in_channels if i == 0 else out_channels
            groups = in_channels // resnet_group_size
            groups_out = out_channels // resnet_group_size

            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    temb_channels=temb_channels,
                    groups=groups,
                    groups_out=groups_out,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    time_embedding_norm="ada_group",
                    conv_shortcut_bias=False,
                )
            )
            attentions.append(
                KAttentionBlock(
                    out_channels,
                    out_channels // attention_head_dim,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    temb_channels=temb_channels,
                    attention_bias=True,
                    add_self_attention=add_self_attention,
                    cross_attention_norm="layer_norm",
                    group_size=resnet_group_size,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_downsample:
            self.downsamplers = nn.ModuleList([KDownsample2D()])
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run every resnet/attention pair, then the downsamplers.

        Returns:
            `(hidden_states, output_states)`. `output_states` gets one entry
            per pair: the pair's output when the block has downsamplers,
            otherwise `None`. The downsampled tensor itself is not appended.
        """
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Call the attention block with exactly the arguments the eager
                # path uses. The previous wrapper also injected
                # `return_dict=False`, which the eager call never passes.
                # NOTE(review): positional order assumed to match
                # KAttentionBlock.forward (hidden_states, encoder_hidden_states,
                # emb, attention_mask, cross_attention_kwargs,
                # encoder_attention_mask) - confirm against its definition.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    attention_mask,
                    cross_attention_kwargs,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    emb=temb,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )

            if self.downsamplers is None:
                output_states += (None,)
            else:
                output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states, output_states
def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, upsample_type="conv", ): super().__init__() resnets = [] attentions = [] self.upsample_type = upsample_type if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." ) attention_head_dim = out_channels for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if upsample_type == "conv": self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) elif upsample_type == "resnet": self.upsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, 
output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, up=True, ) ] ) else: self.upsamplers = None def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": hidden_states = upsampler(hidden_states, temb=temb) else: hidden_states = upsampler(hidden_states) return hidden_states class CrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not 
dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, posemb: Optional = None, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( 
class UpBlock2D(nn.Module):
    """Chain of `ResnetBlock2D` layers fed with skip tensors, plus an optional `Upsample2D`.

    Before each resnet, the next skip tensor (taken from the end of
    `res_hidden_states_tuple`) is concatenated to the hidden states along
    dim 1, which is why each resnet's input width is the sum of two widths.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        def concat_width(layer_idx):
            # First layer continues from the previous block; the last layer's
            # skip tensor has `in_channels` channels, all others `out_channels`.
            carried = prev_output_channel if layer_idx == 0 else out_channels
            skipped = in_channels if layer_idx == num_layers - 1 else out_channels
            return carried + skipped

        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=concat_width(layer_idx),
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                for layer_idx in range(num_layers)
            ]
        )

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Merge skip tensors through the resnets, then upsample if configured.

        `upsample_size` is forwarded as the second argument to every upsampler.
        """

        def wrap(module):
            # Plain-function wrapper bound to one specific module, handed to
            # torch.utils.checkpoint.
            def run(*inputs):
                return module(*inputs)

            return run

        for depth, resnet in enumerate(self.resnets, start=1):
            # Consume skip tensors from the back: -1 first, then -2, ...
            skip = res_hidden_states_tuple[-depth]
            hidden_states = torch.cat([hidden_states, skip], dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
            else:
                # `use_reentrant` is only passed on torch >= 1.11.0.
                extra = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(wrap(resnet), hidden_states, temb, **extra)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class AttnUpDecoderBlock2D(nn.Module):
    """Decoder up-block alternating residual blocks with attention layers.

    One ``ResnetBlock2D`` / ``Attention`` pair is built per layer. No skip
    connections are consumed. When ``add_upsample`` is true a single
    ``Upsample2D`` layer is applied after the stack.

    Parameters:
        in_channels: channels of the input to the first residual block.
        out_channels: channels produced by every residual block.
        attention_head_dim: channels per attention head; ``None`` falls back to
            ``out_channels`` (one head) and logs a warning.
        temb_channels: forwarded to the residual blocks; also used as
            ``spatial_norm_dim`` of the attention layers when
            ``resnet_time_scale_shift == "spatial"``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        add_upsample=True,
        temb_channels=None,
    ):
        super().__init__()
        resnets = []
        attentions = []

        if attention_head_dim is None:
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning(
                f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
            )
            attention_head_dim = out_channels

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Attention(
                    out_channels,
                    heads=out_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    # Group norm and spatial norm are mutually exclusive here:
                    # exactly one of the two arguments is non-None.
                    norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
                    residual_connection=True,
                    bias=True,
                    upcast_softmax=True,
                    _from_deprecated_attn_block=True,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states, temb=None):
        """Apply each (resnet, attention) pair in order, then the optional upsampler."""
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb=temb)
            hidden_states = attn(hidden_states, temb=temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
class SkipUpBlock2D(nn.Module):
    """Up-block that also maintains a running 3-channel ``skip_sample`` output.

    Residual blocks consume skip tensors from the end of
    ``res_hidden_states_tuple``. An incoming ``skip_sample`` is passed through
    a ``FirUpsample2D`` layer. When ``add_upsample`` is true, a norm ->
    SiLU -> 3-channel conv projection of the features is added to it and the
    features themselves go through an up-sampling residual block.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        output_scale_factor=np.sqrt(2.0),
        add_upsample=True,
        upsample_padding=1,
    ):
        super().__init__()

        # Group count shared by every norm that operates on `out_channels`.
        out_groups = min(out_channels // 4, 32)
        final_idx = num_layers - 1

        stack = []
        for idx in range(num_layers):
            main_channels = out_channels if idx else prev_output_channel
            skip_channels = in_channels if idx == final_idx else out_channels
            merged_channels = main_channels + skip_channels
            stack.append(
                ResnetBlock2D(
                    in_channels=merged_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=min(merged_channels // 4, 32),
                    groups_out=out_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(stack)

        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

        if not add_upsample:
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None
        else:
            self.resnet_up = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=out_groups,
                groups_out=out_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_in_shortcut=True,
                up=True,
                kernel="fir",
            )
            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.skip_norm = torch.nn.GroupNorm(
                num_groups=out_groups, num_channels=out_channels, eps=resnet_eps, affine=True
            )
            self.act = nn.SiLU()

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
        """Return ``(hidden_states, skip_sample)``.

        ``skip_sample`` is ``0`` when none was supplied and no up-sampling
        branch exists.
        """
        for depth, resnet in enumerate(self.resnets, start=1):
            # Skips are taken from the end of the tuple, one per block.
            skip = res_hidden_states_tuple[-depth]
            hidden_states = resnet(torch.cat([hidden_states, skip], dim=1), temb)

        skip_sample = 0 if skip_sample is None else self.upsampler(skip_sample)

        if self.resnet_up is None:
            return hidden_states, skip_sample

        skip_update = self.skip_conv(self.act(self.skip_norm(hidden_states)))
        skip_sample = skip_sample + skip_update
        hidden_states = self.resnet_up(hidden_states, temb)
        return hidden_states, skip_sample
class SimpleCrossAttnUpBlock2D(nn.Module):
    """Up-block pairing residual blocks with added-KV cross-attention layers.

    Each layer concatenates one skip tensor (taken from the end of
    ``res_hidden_states_tuple``) onto the features, runs a ``ResnetBlock2D``
    and then an ``Attention`` layer built with ``added_kv_proj_dim``. When
    ``add_upsample`` is true, an up-sampling ``ResnetBlock2D`` (``up=True``)
    is applied after the stack.

    Parameters:
        in_channels: channel count of the skip tensor paired with the last layer.
        out_channels: channels produced by every layer.
        prev_output_channel: channels of the incoming feature map.
        attention_head_dim: channels per head; the head count is
            ``out_channels // attention_head_dim``.
        cross_attention_dim: passed to ``Attention`` as ``added_kv_proj_dim``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        skip_time_act=False,
        only_cross_attention=False,
        cross_attention_norm=None,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        self.num_heads = out_channels // self.attention_head_dim

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    skip_time_act=skip_time_act,
                )
            )

            # Use the 2.0 processor when this torch build exposes
            # `F.scaled_dot_product_attention`.
            processor = (
                AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
            )

            attentions.append(
                Attention(
                    query_dim=out_channels,
                    cross_attention_dim=out_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    ResnetBlock2D(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                        skip_time_act=skip_time_act,
                        up=True,
                    )
                ]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run the (resnet, attention) stack and the optional up-sampling resnets.

        ``upsample_size`` is accepted for interface parity with the other
        up-blocks but is not used here: the upsamplers are residual blocks
        that take ``temb``.
        """
        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        if attention_mask is None:
            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
            mask = None if encoder_hidden_states is None else encoder_attention_mask
        else:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask

        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # Pass `use_reentrant` explicitly where supported, as the
                # other blocks in this file do.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            # Attention is called the same way in both modes. The previous
            # checkpointed call passed `cross_attention_kwargs` as one
            # positional dict, injected `return_dict=False`, and indexed the
            # result with `[0]` -- none of which matches this plain call, whose
            # result is used directly as the new hidden states.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **cross_attention_kwargs,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, temb)

        return hidden_states
class KCrossAttnUpBlock2D(nn.Module):
    """K-UNet up-block: residual blocks each followed by a ``KAttentionBlock``.

    Only the last element of ``res_hidden_states_tuple`` is used: if it is not
    ``None`` it is concatenated once onto the input before the stack runs.
    ``num_layers - 1`` (resnet, attention) pairs are built. When
    ``add_upsample`` is true a ``KUpsample2D`` layer is applied afterwards.

    Parameters:
        in_channels: channels produced by the block (used as the k-unet output
            width of the final layer).
        out_channels: internal width of the block.
        temb_channels: embedding width handed to the residual and attention blocks.
        attention_head_dim: channels per attention head.
        cross_attention_dim: forwarded to every ``KAttentionBlock``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 4,
        resnet_eps: float = 1e-5,
        resnet_act_fn: str = "gelu",
        resnet_group_size: int = 32,
        attention_head_dim=1,  # attention dim_head
        cross_attention_dim: int = 768,
        add_upsample: bool = True,
        upcast_attention: bool = False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        is_first_block = in_channels == out_channels == temb_channels
        is_middle_block = in_channels != out_channels
        # Self-attention is only added when all three widths coincide.
        add_self_attention = bool(is_first_block)

        self.has_cross_attention = True
        self.attention_head_dim = attention_head_dim

        # in_channels, and out_channels for the block (k-unet)
        k_in_channels = out_channels if is_first_block else 2 * out_channels
        k_out_channels = in_channels

        num_layers = num_layers - 1
        last_layer = num_layers - 1

        for i in range(num_layers):
            layer_in_channels = k_in_channels if i == 0 else out_channels
            groups = layer_in_channels // resnet_group_size
            groups_out = out_channels // resnet_group_size

            # Only the final resnet of a block whose widths differ projects to
            # the k-unet output width.
            if is_middle_block and i == last_layer:
                conv_2d_out_channels = k_out_channels
            else:
                conv_2d_out_channels = None

            resnets.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    conv_2d_out_channels=conv_2d_out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=groups,
                    groups_out=groups_out,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                    time_embedding_norm="ada_group",
                    conv_shortcut_bias=False,
                )
            )

            attn_dim = k_out_channels if i == last_layer else out_channels
            attentions.append(
                KAttentionBlock(
                    attn_dim,
                    attn_dim // attention_head_dim,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    temb_channels=temb_channels,
                    attention_bias=True,
                    add_self_attention=add_self_attention,
                    cross_attention_norm="layer_norm",
                    upcast_attention=upcast_attention,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

        if add_upsample:
            self.upsamplers = nn.ModuleList([KUpsample2D()])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Concatenate the last skip tensor, run the stack, then upsample.

        ``upsample_size`` is accepted for interface parity with the other
        up-blocks; the upsamplers here are called with the features only.
        """
        res_hidden_states_tuple = res_hidden_states_tuple[-1]
        if res_hidden_states_tuple is not None:
            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            # `KAttentionBlock.forward` takes no `return_dict` argument and
            # returns a tensor, so the former checkpointed call (which passed
            # `return_dict=False` and indexed the result with `[0]`) could not
            # work. The attention block is therefore invoked identically in
            # both modes.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                emb=temb,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, upcast_attention: bool = False, temb_channels: int = 768, # for ada_group_norm add_self_attention: bool = False, cross_attention_norm: Optional[str] = None, group_size: int = 32, ): super().__init__() self.add_self_attention = add_self_attention # 1. Self-Attn if add_self_attention: self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=None, cross_attention_norm=None, ) # 2. Cross-Attn self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, cross_attention_norm=cross_attention_norm, ) def _to_3d(self, hidden_states, height, weight): return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) def _to_4d(self, hidden_states, height, weight): return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, # TODO: mark emb as non-optional (self.norm2 requires it). # requires assessing impact of change to positional param interface. 
emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} # 1. Self-Attention if self.add_self_attention: norm_hidden_states = self.norm1(hidden_states, emb) height, weight = norm_hidden_states.shape[2:] norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) hidden_states = attn_output + hidden_states # 2. Cross-Attention/None norm_hidden_states = self.norm2(hidden_states, emb) height, weight = norm_hidden_states.shape[2:] norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) hidden_states = attn_output + hidden_states return hidden_states ================================================ FILE: 6DoF/diffusers/models/unet_2d_blocks_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross Attention 2D Downsizing block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    Builds ``num_layers`` pairs of residual block and transformer block, plus an optional downsampling layer. Every
    intermediate feature map is also returned so that it can be reused as a skip connection.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        res_layers = []
        attn_layers = []

        # Per-head width is fixed by the block's output width.
        head_dim = self.out_channels // self.num_attention_heads

        for layer_idx in range(self.num_layers):
            # Only the first residual block sees the block's input width.
            layer_in = self.out_channels if layer_idx else self.in_channels

            res_layers.append(
                FlaxResnetBlock2D(
                    in_channels=layer_in,
                    out_channels=self.out_channels,
                    dropout_prob=self.dropout,
                    dtype=self.dtype,
                )
            )
            attn_layers.append(
                FlaxTransformer2DModel(
                    in_channels=self.out_channels,
                    n_heads=self.num_attention_heads,
                    d_head=head_dim,
                    depth=1,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=self.only_cross_attention,
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            )

        self.resnets = res_layers
        self.attentions = attn_layers

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        # One entry per layer, plus one for the downsampled map when present.
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104

    Each layer concatenates one skip tensor (taken from the end of `res_hidden_states_tuple`) onto the features along
    the last axis, then applies a residual block followed by a transformer block.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsampling layer before each final output
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each `FlaxTransformer2DModel`
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        res_layers = []
        attn_layers = []

        final_idx = self.num_layers - 1
        head_dim = self.out_channels // self.num_attention_heads

        for layer_idx in range(self.num_layers):
            # First layer takes the previous block's output; the last layer is
            # paired with a skip tensor of `in_channels` channels.
            main_channels = self.out_channels if layer_idx else self.prev_output_channel
            skip_channels = self.in_channels if layer_idx == final_idx else self.out_channels

            res_layers.append(
                FlaxResnetBlock2D(
                    in_channels=main_channels + skip_channels,
                    out_channels=self.out_channels,
                    dropout_prob=self.dropout,
                    dtype=self.dtype,
                )
            )
            attn_layers.append(
                FlaxTransformer2DModel(
                    in_channels=self.out_channels,
                    n_heads=self.num_attention_heads,
                    d_head=head_dim,
                    depth=1,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=self.only_cross_attention,
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            )

        self.resnets = res_layers
        self.attentions = attn_layers

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        for depth, (resnet, attn) in enumerate(zip(self.resnets, self.attentions), start=1):
            # The depth-th layer uses the depth-th skip counted from the end.
            skip = res_hidden_states_tuple[-depth]
            hidden_states = jnp.concatenate((hidden_states, skip), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)

        if not self.add_upsample:
            return hidden_states

        return self.upsamplers_0(hidden_states)
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104

    The block is one resnet followed by `num_layers` (transformer, resnet) pairs; the channel count is kept at
    `in_channels` throughout.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to every `FlaxTransformer2DModel` created by this block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=1,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        # The leading resnet must honour `deterministic` like every other layer; otherwise its
        # dropout silently stays disabled when the caller asks for stochastic (training) behaviour.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    The output of [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Declared with a `None` default, so the field may be omitted when the dataclass is constructed.
    sample: torch.FloatTensor = None
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers are skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. 
num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. 
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads=64,
):
    """
    Validate the configuration and build every sub-module of the UNet.

    Sub-modules are created in this order: `conv_in`, time projection/embedding, optional encoder
    projection, optional class and additional embeddings, down blocks, mid block, up blocks and the
    output norm/activation/convolution.

    Raises:
        ValueError: If `num_attention_heads` is passed, if the per-block arguments do not have one entry
            per down block, or if an embedding / block type string is not recognised.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )

    # Checked like the other per-block arguments so that tuples (the annotated type) are validated
    # too, not only lists.
    if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # input
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # time
    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        self.time_proj = GaussianFourierProjection(
            time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
        )
        timestep_input_dim = time_embed_dim
    elif time_embedding_type == "positional":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
    else:
        raise ValueError(
            f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
        )

    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
        self.encoder_hid_proj = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2
        self.encoder_hid_proj = ImageProjection(
            image_embed_dim=encoder_hid_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type is not None:
        raise ValueError(
            f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
        )
    else:
        self.encoder_hid_proj = None

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
        # 2. it projects from an arbitrary input dimension.
        #
        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
        self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
            )
        self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
    else:
        self.class_embedding = None

    if addition_embed_type == "text":
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim

        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type is not None:
        raise ValueError(
            f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
        )

    if time_embedding_act_fn is None:
        self.time_embed_act = None
    else:
        self.time_embed_act = get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar settings to one entry per down block.
    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention

        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    if class_embeddings_concat:
        # The time embeddings are concatenated with the class embeddings. The dimension of the
        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
        # regular time embeddings
        blocks_time_embed_dim = time_embed_dim * 2
    else:
        blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.down_blocks.append(down_block)

    # mid
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )
    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        self.mid_block = UNetMidBlock2DSimpleCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim[-1],
            attention_head_dim=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif mid_block_type is None:
        self.mid_block = None
    else:
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            # NOTE(review): unlike the other per-block settings this one is indexed without being
            # reversed. Kept as-is because changing it could alter the architecture of existing
            # checkpoints that pass a tuple - confirm before touching.
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.up_blocks.append(up_block)

    # out
    if norm_num_groups is not None:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )

        self.conv_act = get_activation(act_fn)

    else:
        self.conv_norm_out = None
        self.conv_act = None

    conv_out_padding = (conv_out_kernel - 1) // 2
    self.conv_out = nn.Conv2d(
        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
    )
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict is only
            read; it is left unchanged for the caller.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the `"<module path>.processor"` key of some attention layer.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Look the entry up instead of popping it so the caller's dict is not emptied.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size):
    r"""
    Turn on sliced attention for every sliceable sub-module.

    With slicing enabled an attention module processes its input in several steps rather than all at once,
    trading a little speed for lower memory use.

    Args:
        slice_size (`str` or `int` or `list(int)`):
            `"auto"` halves each module's `sliceable_head_dim`, `"max"` uses a slice size of 1 everywhere, a
            single number is applied to every sliceable module, and a list supplies one value per sliceable
            module in traversal order. Every non-`None` value must not exceed the module's
            `sliceable_head_dim`.

    Raises:
        ValueError: If a list of the wrong length is given or a slice size exceeds the matching
            `sliceable_head_dim`.
    """

    def _is_sliceable(module: torch.nn.Module) -> bool:
        return hasattr(module, "set_attention_slice")

    def _collect_head_dims(module: torch.nn.Module, collected: List[int]):
        if _is_sliceable(module):
            collected.append(module.sliceable_head_dim)
        for sub_module in module.children():
            _collect_head_dims(sub_module, collected)

    # first pass: gather the head dim of every sliceable layer
    sliceable_head_dims = []
    for top_level in self.children():
        _collect_head_dims(top_level, sliceable_head_dims)

    layer_count = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [head_dim // 2 for head_dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = layer_count * [1]

    if not isinstance(slice_size, list):
        slice_size = layer_count * [slice_size]

    if len(slice_size) != len(sliceable_head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # second pass: hand each sliceable layer its value, in the same traversal order.
    # Values are taken from the end of a reversed copy, so the caller's list is untouched.
    def _apply(module: torch.nn.Module, pending: List[int]):
        if _is_sliceable(module):
            module.set_attention_slice(pending.pop())
        for sub_module in module.children():
            _apply(sub_module, pending)

    pending_sizes = slice_size[::-1]
    for top_level in self.children():
        _apply(top_level, pending_sizes)
encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. 
class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = 
added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) ================================================ FILE: 6DoF/diffusers/models/unet_2d_condition_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
    shaped output.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        sample_size (`int`, *optional*, defaults to 32):
            The size of the input sample. Used by `init_weights` as both height and width of the dummy input.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use. `"CrossAttnDownBlock2D"` selects `FlaxCrossAttnDownBlock2D`;
            any other value selects `FlaxDownBlock2D`.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use. `"CrossAttnUpBlock2D"` selects `FlaxCrossAttnUpBlock2D`; any
            other value selects `FlaxUpBlock2D`.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Forwarded per block to the cross-attention down/up blocks. A single `bool` is repeated once per entry
            of `down_block_types`.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per down block. Up blocks are built with `layers_per_block + 1` layers.
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            Despite its name, this value is passed to the blocks as their `num_attention_heads` (see the naming
            note in `setup`). A single `int` is repeated once per entry of `down_block_types`.
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            Must currently be left as `None`; `setup` raises a `ValueError` for any other value.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            Forwarded to the cross-attention down/up blocks and to the mid block.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            Forwarded to the input/output convolutions, the time embedding and all blocks.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0):
            The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
    """

    sample_size: int = 32
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    only_cross_attention: Union[bool, Tuple[bool]] = False
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        """
        Initialize the model parameters by tracing the module on dummy inputs.

        Args:
            rng (`jax.random.KeyArray`): Key that is split into a `"params"` and a `"dropout"` stream.

        Returns:
            The `"params"` collection produced by `self.init`.
        """
        # init input tensors
        # Dummy batch of one: zeros sample in (batch, channel, height, width) layout, a single timestep
        # and a single-token encoder state.
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self):
        """
        Build the submodules: input conv, time embedding, down blocks, mid block, up blocks and output head.

        Raises:
            ValueError: If `num_attention_heads` is not `None`.
        """
        block_out_channels = self.block_out_channels
        # The time embedding is four times as wide as the first block.
        time_embed_dim = block_out_channels[0] * 4

        if self.num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
        # which is why we correct for the naming here.
        # NOTE: because of the guard above, `self.num_attention_heads` is always `None` here, so this
        # always resolves to `self.attention_head_dim`.
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        # Broadcast scalar settings to one entry per down block.
        only_cross_attention = self.only_cross_attention
        if isinstance(only_cross_attention, bool):
            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # down
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            # Each block consumes the previous block's output width.
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # The last block keeps the spatial resolution (no downsampler).
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock2D":
                down_block = FlaxCrossAttnDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
                # Any other block type string falls back to the plain down block.
                down_block = FlaxDownBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )

            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # mid
        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
        )

        # up
        # The up path mirrors the down path, so all per-block settings are walked in reverse.
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # Width of the next entry in the reversed list, clamped at the last entry.
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "CrossAttnUpBlock2D":
                up_block = FlaxCrossAttnUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    num_attention_heads=reversed_num_attention_heads[i],
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=only_cross_attention[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    dtype=self.dtype,
                )
            else:
                # Any other block type string falls back to the plain up block.
                up_block = FlaxUpBlock2D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    dtype=self.dtype,
                )

            up_blocks.append(up_block)
            # NOTE: redundant; `prev_output_channel` is reassigned at the top of the next iteration.
            prev_output_channel = output_channel
        self.up_blocks = up_blocks

        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample,
        timesteps,
        encoder_hidden_states,
        down_block_additional_residuals=None,
        mid_block_additional_residual=None,
        return_dict: bool = True,
        train: bool = False,
    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            down_block_additional_residuals (iterable of `jnp.ndarray`, *optional*):
                If given, paired (via `zip`) with the collected down-path residuals and added to them one by one.
            mid_block_additional_residual (`jnp.ndarray`, *optional*):
                If given, added to the output of the mid block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
                plain tuple.
            train (`bool`, *optional*, defaults to `False`):
                Use deterministic functions and disable dropout when not training.

        Returns:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
        # Non-array timesteps become a length-1 int32 array; 0-d arrays are cast to float32 and
        # given a leading axis. Arrays with one or more dimensions are used unchanged.
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
        # Move the channel axis last (axes 0, 2, 3, 1) before the convolution; undone in step 6.
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        # 3. down
        # Collect every intermediate output so the up path can consume them as skip connections.
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock2D):
                sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
            else:
                sample, res_samples = down_block(sample, t_emb, deterministic=not train)
            down_block_res_samples += res_samples

        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            # `zip` stops at the shorter input, so unmatched residuals are dropped from the result.
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample += down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

        if mid_block_additional_residual is not None:
            sample += mid_block_additional_residual

        # 5. up
        for up_block in self.up_blocks:
            # Each up block takes the last `layers_per_block + 1` residuals off the end of the stack.
            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock2D):
                sample = up_block(
                    sample,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states,
                    res_hidden_states_tuple=res_samples,
                    deterministic=not train,
                )
            else:
                sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        # Restore the channel-second layout of the input (axes 0, 3, 1, 2).
        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)

        return FlaxUNet2DConditionOutput(sample=sample)
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """
    Instantiate the 3D down block selected by `down_block_type`.

    `"DownBlock3D"` and `"CrossAttnDownBlock3D"` are supported. The attention-related arguments
    (`num_attention_heads`, `cross_attention_dim`, `dual_cross_attention`, `use_linear_projection`,
    `only_cross_attention`, `upcast_attention`) are forwarded only to the cross-attention variant.

    Raises:
        ValueError: If `down_block_type` is not one of the supported names, or if
            `"CrossAttnDownBlock3D"` is requested while `cross_attention_dim` is `None`.
    """
    # Keyword arguments that both block classes receive.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type != "CrossAttnDownBlock3D":
        raise ValueError(f"{down_block_type} does not exist.")

    if cross_attention_dim is None:
        raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")

    return CrossAttnDownBlock3D(
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        **shared_kwargs,
    )


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=True,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    """
    Instantiate the 3D up block selected by `up_block_type`.

    `"UpBlock3D"` and `"CrossAttnUpBlock3D"` are supported. The attention-related arguments
    (`num_attention_heads`, `cross_attention_dim`, `dual_cross_attention`, `use_linear_projection`,
    `only_cross_attention`, `upcast_attention`) are forwarded only to the cross-attention variant.

    Raises:
        ValueError: If `up_block_type` is not one of the supported names, or if
            `"CrossAttnUpBlock3D"` is requested while `cross_attention_dim` is `None`.
    """
    # Keyword arguments that both block classes receive.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "prev_output_channel": prev_output_channel,
        "temb_channels": temb_channels,
        "add_upsample": add_upsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type != "CrossAttnUpBlock3D":
        raise ValueError(f"{up_block_type} does not exist.")

    if cross_attention_dim is None:
        raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")

    return CrossAttnUpBlock3D(
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        **shared_kwargs,
    )
num_attention_heads=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=True, upcast_attention=False, ): super().__init__() self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] temp_convs = [ TemporalConvLayer( in_channels, in_channels, dropout=0.1, ) ] attentions = [] temp_attentions = [] for _ in range(num_layers): attentions.append( Transformer2DModel( in_channels // num_attention_heads, num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) ) temp_attentions.append( TransformerTemporalModel( in_channels // num_attention_heads, num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( in_channels, in_channels, dropout=0.1, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) def forward( self, hidden_states, temb=None, encoder_hidden_states=None, 
attention_mask=None, num_frames=1, cross_attention_kwargs=None, ): hidden_states = self.resnets[0](hidden_states, temb) hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) for attn, temp_attn, resnet, temp_conv in zip( self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] ): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] hidden_states = temp_attn( hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False )[0] hidden_states = resnet(hidden_states, temb) hidden_states = temp_conv(hidden_states, num_frames=num_frames) return hidden_states class CrossAttnDownBlock3D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, ): super().__init__() resnets = [] attentions = [] temp_attentions = [] temp_convs = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( out_channels, out_channels, dropout=0.1, ) ) attentions.append( Transformer2DModel( out_channels // num_attention_heads, 
class DownBlock3D(nn.Module):
    """Attention-free down stage: resnet/temporal-conv pairs followed by an optional downsampler.

    `forward` returns the final hidden states plus a tuple containing the output of every
    resnet/temporal-conv pair and, when downsamplers exist, the downsampled result.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        spatial_layers = []
        temporal_layers = []
        # Layers are built pair by pair so that the construction order stays resnet, temp-conv, ...
        for layer_idx in range(num_layers):
            # Only the first resnet consumes the block's input width.
            layer_in_channels = out_channels if layer_idx > 0 else in_channels
            spatial_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temporal_layers.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1))

        self.resnets = nn.ModuleList(spatial_layers)
        self.temp_convs = nn.ModuleList(temporal_layers)

        downsampler_modules = None
        if add_downsample:
            downsampler_modules = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        self.downsamplers = downsampler_modules

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, num_frames=1):
        """Run every resnet/temporal-conv pair, then the downsamplers; collect the skip outputs."""
        skip_outputs = []

        for spatial, temporal in zip(self.resnets, self.temp_convs):
            hidden_states = temporal(spatial(hidden_states, temb), num_frames=num_frames)
            skip_outputs.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            # A single entry is recorded after all downsamplers have run.
            skip_outputs.append(hidden_states)

        return hidden_states, tuple(skip_outputs)
use_linear_projection=False, only_cross_attention=False, upcast_attention=False, ): super().__init__() resnets = [] temp_convs = [] attentions = [] temp_attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) temp_convs.append( TemporalConvLayer( out_channels, out_channels, dropout=0.1, ) ) attentions.append( Transformer2DModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, ) ) temp_attentions.append( TransformerTemporalModel( out_channels // num_attention_heads, num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.resnets = nn.ModuleList(resnets) self.temp_convs = nn.ModuleList(temp_convs) self.attentions = nn.ModuleList(attentions) self.temp_attentions = nn.ModuleList(temp_attentions) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None, attention_mask=None, num_frames=1, cross_attention_kwargs=None, ): # TODO(Patrick, William) - attention mask is not 
class UpBlock3D(nn.Module):
    """Attention-free up stage: each layer concatenates one skip tensor, then applies a resnet and a temporal conv.

    Skip tensors are consumed from the end of `res_hidden_states_tuple`, one per layer. An optional
    upsampler runs after the last layer.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        spatial_layers = []
        temporal_layers = []
        final_idx = num_layers - 1
        # Layers are built pair by pair so that the construction order stays resnet, temp-conv, ...
        for layer_idx in range(num_layers):
            # The last layer receives the skip tensor with `in_channels` channels; earlier ones get `out_channels`.
            skip_width = in_channels if layer_idx == final_idx else out_channels
            # The first layer receives the previous stage's output; later ones this stage's own output.
            main_width = prev_output_channel if layer_idx == 0 else out_channels
            spatial_layers.append(
                ResnetBlock2D(
                    in_channels=main_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            temporal_layers.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1))

        self.resnets = nn.ModuleList(spatial_layers)
        self.temp_convs = nn.ModuleList(temporal_layers)

        upsampler_modules = None
        if add_upsample:
            upsampler_modules = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        self.upsamplers = upsampler_modules

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
        """Fuse the skip tensors (last one first) into `hidden_states`, then upsample if configured."""
        for depth, (spatial, temporal) in enumerate(zip(self.resnets, self.temp_convs), start=1):
            # The depth-th layer takes the depth-th tensor counted from the end.
            skip = res_hidden_states_tuple[-depth]
            fused = torch.cat([hidden_states, skip], dim=1)
            hidden_states = temporal(spatial(fused, temb), num_frames=num_frames)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    The output of [`UNet3DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
            The layout is channels-first: `UNet3DConditionModel.forward` reshapes its result to
            `(batch, num_frames, channels, height, width)` and then permutes axes 1 and 2 before returning.
    """

    # Denoised sample; channels are on axis 1 and frames on axis 2.
    sample: torch.FloatTensor
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, the normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*): The number of attention heads.
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    ),
    up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: int = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: int = 1024,
    attention_head_dim: Union[int, Tuple[int]] = 64,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
):
    """
    Build the input convolution, timestep embedding, input temporal transformer, down/mid/up stages and the
    output head.

    Raises:
        NotImplementedError: If `num_attention_heads` is passed (any value other than `None`).
        ValueError: If `down_block_types`, `up_block_types` and `block_out_channels` do not all have the same
            length, or if a non-int head specification has a different length than `down_block_types`.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise NotImplementedError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    # Because of the guard above, `num_attention_heads` is always `None` here, so this always
    # resolves to `attention_head_dim`.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    # input
    conv_in_kernel = 3
    conv_out_kernel = 3
    # (kernel - 1) // 2 is the padding that preserves spatial size for an odd kernel at stride 1.
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # time
    # The timestep embedding is four times as wide as the first block.
    time_embed_dim = block_out_channels[0] * 4
    # NOTE(review): the positional `True, 0` are presumably `flip_sin_to_cos` and
    # `downscale_freq_shift` -- confirm against the `Timesteps` signature.
    self.time_proj = Timesteps(block_out_channels[0], True, 0)
    timestep_input_dim = block_out_channels[0]

    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
    )

    # Temporal transformer applied right after `conv_in`. The head count is hard-coded to 8.
    # NOTE(review): `attention_head_dim` is annotated as possibly a tuple but is forwarded unchanged
    # here -- confirm `TransformerTemporalModel` is only ever built with an int.
    self.transformer_in = TransformerTemporalModel(
        num_attention_heads=8,
        attention_head_dim=attention_head_dim,
        in_channels=block_out_channels[0],
        num_layers=1,
    )

    # containers for the down/up stages (no class embedding is created in this constructor)
    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast a single head specification to one entry per down block.
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        # Each stage consumes the previous stage's output width.
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block,
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=time_embed_dim,
            # every stage except the last one downsamples
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            dual_cross_attention=False,
        )
        self.down_blocks.append(down_block)

    # mid
    self.mid_block = UNetMidBlock3DCrossAttn(
        in_channels=block_out_channels[-1],
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads[-1],
        resnet_groups=norm_num_groups,
        dual_cross_attention=False,
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    # The up path walks the channel widths and head specifications in reverse order.
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        # Width of the next entry in the reversed list, clamped at the last index.
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        up_block = get_up_block(
            up_block_type,
            # one more layer than the down stages
            num_layers=layers_per_block + 1,
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=reversed_num_attention_heads[i],
            dual_cross_attention=False,
        )
        self.up_blocks.append(up_block)
        # NOTE(review): this assignment is overwritten at the top of the next iteration and is not read afterwards.
        prev_output_channel = output_channel

    # out
    if norm_num_groups is not None:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )
        # The output activation is always SiLU; `act_fn` is not consulted here.
        self.conv_act = nn.SiLU()
    else:
        self.conv_norm_out = None
        self.conv_act = None

    conv_out_padding = (conv_out_kernel - 1) // 2
    self.conv_out = nn.Conv2d(
        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
    )
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

            A dict is only read, never modified, so the same dict can be passed again later.

    Raises:
        ValueError: If `processor` is a dict whose length differs from the number of attention layers.
        KeyError: If `processor` is a dict that lacks the `"<module path>.processor"` key of some attention layer.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Look the entry up instead of popping it: popping would empty the caller's dict
                # and make a second call with the same dict fail the length check above.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
    r"""
    The [`UNet3DConditionModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width)`.
            The number of frames is read from axis 2.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.FloatTensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. They are repeated
            `num_frames` times along the batch axis before being passed to the blocks.
        class_labels (`torch.Tensor`, *optional*):
            Accepted as part of the signature; this method does not read it.
        timestep_cond (`torch.Tensor`, *optional*):
            Passed as the second argument to the time embedding.
        attention_mask (`torch.Tensor`, *optional*):
            If given, converted to an additive bias (`0` where the mask is 1, `-10000.0` where it is 0), given
            an extra axis at position 1 and passed to the cross-attention blocks and the mid block.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
        down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            Tensors added to the down-path skip tensors, paired in order.
        mid_block_additional_residual (`torch.Tensor`, *optional*):
            Tensor added to the hidden states after the mid block.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
            tuple.

    Returns:
        [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
            a `tuple` is returned where the first element is the sample tensor. The sample has shape
            `(batch, num_channels, num_frames, height, width)`.

    Raises:
        ValueError: If `sample` is not a 5D tensor.
    """
    # Frames are folded into the batch axis below, which only works for a 5D input. Fail early
    # with a clear message instead of an opaque permute/reshape error.
    if sample.ndim != 5:
        raise ValueError(
            "`sample` must be a 5D tensor of shape (batch, num_channels, num_frames, height, width),"
            f" but got a tensor with {sample.ndim} dimensions."
        )

    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    num_frames = sample.shape[2]
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    # Every frame becomes its own batch entry below, so the per-sample conditioning is repeated per frame.
    emb = emb.repeat_interleave(repeats=num_frames, dim=0)
    encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

    # 2. pre-process
    # (batch, channels, frames, h, w) -> (batch * frames, channels, h, w)
    sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
    sample = self.conv_in(sample)

    sample = self.transformer_in(
        sample,
        num_frames=num_frames,
        cross_attention_kwargs=cross_attention_kwargs,
        return_dict=False,
    )[0]

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames)

        down_block_res_samples += res_samples

    if down_block_additional_residuals is not None:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples += (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            num_frames=num_frames,
            cross_attention_kwargs=cross_attention_kwargs,
        )

    if mid_block_additional_residual is not None:
        sample = sample + mid_block_additional_residual

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Each up block consumes as many skip tensors as it has resnets, taken from the end.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                num_frames=num_frames,
            )

    # 6. post-process
    # Explicit `None` check, matching the `mid_block` test above, instead of relying on module truthiness.
    if self.conv_norm_out is not None:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)

    sample = self.conv_out(sample)

    # (batch * frames, channels, h, w) -> (batch, frames, channels, h, w) -> (batch, channels, frames, h, w)
    sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4)

    if not return_dict:
        return (sample,)

    return UNet3DConditionOutput(sample=sample)
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from ..utils import BaseOutput, is_torch_version, randn_tensor
from .attention_processor import SpatialNorm
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class Encoder(nn.Module):
    """
    Convolutional encoder: input conv -> down blocks -> mid block -> GroupNorm / SiLU / output conv.

    Args:
        in_channels: Number of channels expected by the input convolution.
        out_channels: Base number of channels produced by `conv_out` (doubled when `double_z` is set).
        down_block_types: Block type names, one per stage, forwarded to `get_down_block`.
        block_out_channels: Output channels of each stage; indexed in lockstep with `down_block_types`.
        layers_per_block: `num_layers` passed to every down block.
        norm_num_groups: Number of groups used by the group norms.
        act_fn: Activation name forwarded to the down and mid blocks. The final activation is always SiLU.
        double_z: If `True`, `conv_out` emits `2 * out_channels` channels.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        double_z=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = torch.nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # the last stage keeps its resolution (no downsampler)
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x):
        """
        Run `x` through the encoder and return the output of `conv_out`.

        When the module is in training mode and `gradient_checkpointing` is enabled, the down and mid blocks
        are executed through `torch.utils.checkpoint.checkpoint`.
        """
        sample = x
        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch versions >= 1.11.0
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # down
            for down_block in self.down_blocks:
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample, **ckpt_kwargs)
            # middle
            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, **ckpt_kwargs)
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample)

            # middle
            sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class Decoder(nn.Module):
    """
    Convolutional decoder: input conv -> mid block -> up blocks -> norm / SiLU / output conv.

    Args:
        in_channels: Number of channels expected by the input convolution.
        out_channels: Number of channels produced by `conv_out`.
        up_block_types: Block type names, one per stage, forwarded to `get_up_block`.
        block_out_channels: Stage widths; consumed in reverse order by the up blocks.
        layers_per_block: Every up block is built with `layers_per_block + 1` layers.
        norm_num_groups: Number of groups used by the group norms.
        act_fn: Activation name forwarded to the mid and up blocks. The final activation is always SiLU.
        norm_type: `"group"` or `"spatial"`. With `"spatial"`, `SpatialNorm` is used for the output norm and
            `in_channels` is used as `temb_channels` for the blocks.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # registered before the mid block so the sub-module order is unchanged
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # the last stage keeps its resolution (no upsampler)
            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z, latent_embeds=None):
        """
        Decode `z` and return the output of `conv_out`.

        Args:
            z: Latent tensor fed to `conv_in`.
            latent_embeds: Optional second argument passed to the mid block, every up block and, when it is
                not `None`, to `conv_norm_out`.
        """
        sample = z
        sample = self.conv_in(sample)

        # the mid block output is cast to the dtype of the up blocks' parameters before upsampling
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch versions >= 1.11.0
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # middle
            sample = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.mid_block), sample, latent_embeds, **ckpt_kwargs
            )
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), sample, latent_embeds, **ckpt_kwargs
                )
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.

    Args:
        n_e: Number of codebook entries.
        vq_embed_dim: Dimension of each codebook vector.
        beta: Weight of one of the two loss terms (which one depends on `legacy`).
        remap: Optional path passed to `np.load`; the loaded array is stored as the buffer `used`.
        unknown_index: `"random"`, `"extra"` or an integer; how indices missing from `used` are remapped.
        sane_index_shape: If `True`, `forward` returns indices reshaped to `(batch, height, width)`.
        legacy: See the note below.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
    ):
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # reserve one additional index after the used ones for "unknown"
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """
        Map codebook indices to their positions in `self.used`.

        Indices that do not occur in `used` are replaced by a random index in `[0, re_embed)` when
        `unknown_index == "random"`, otherwise by `unknown_index`. `inds` must have at least two dimensions;
        the result has the same shape as `inds`.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """
        Inverse of `remap_to_used`: map positions in `self.used` back to codebook indices.

        `inds` must have at least two dimensions; the result has the same shape. The input tensor is not
        modified.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # Out-of-place on purpose: `reshape` may return a view of the caller's tensor, and an in-place
            # write here would silently alter the caller's indices.
            inds = inds.masked_fill(inds >= self.used.shape[0], 0)  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z):
        """
        Quantize `z` against the codebook.

        `z` is permuted with `(0, 2, 3, 1)`, so it must be 4-D with the channel axis in position 1, and its
        channel size must be compatible with `vq_embed_dim`.

        Returns:
            `(z_q, loss, (perplexity, min_encodings, min_encoding_indices))` where `z_q` has the layout of
            the input, and `perplexity` and `min_encodings` are always `None`.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)

        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """
        Look up codebook vectors for `indices`.

        `shape` specifies `(batch, height, width, channel)`. When it is not `None` the result is reshaped to
        it and permuted to channel-first. `shape` is required when `remap` is set, because `shape[0]` is
        read.
        """
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class DiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian parameterised by a tensor holding mean and log-variance stacked along dim 1.

    Args:
        parameters: Tensor split into two equal chunks along dim 1: `mean` and `logvar`. `logvar` is clamped
            to `[-30, 20]`.
        deterministic: If `True`, `std` and `var` are zero tensors, so `sample()` returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw `mean + std * eps`, with `eps` produced by `randn_tensor` using the optional `generator`."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None):
        """
        KL divergence to `other`, or to a standard normal when `other` is `None`, summed over dims 1-3.

        In deterministic mode a single-element zero tensor on the parameters' device and dtype is returned.
        """
        if self.deterministic:
            # same device/dtype as the zero std/var built in __init__ (was a CPU float32 tensor before)
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)
        if other is None:
            return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=(1, 2, 3))
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=(1, 2, 3),
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """
        Negative log-likelihood of `sample`, summed over `dims`.

        In deterministic mode a single-element zero tensor on the parameters' device and dtype is returned.
        """
        if self.deterministic:
            return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
================================================ FILE: 6DoF/diffusers/models/vae_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers import math from functools import partial from typing import Tuple import flax import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..utils import BaseOutput from .modeling_flax_utils import FlaxModelMixin @flax.struct.dataclass class FlaxDecoderOutput(BaseOutput): """ Output of decoding method. Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): The decoded output sample from the last layer of the model. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): The `dtype` of the parameters. """ sample: jnp.ndarray @flax.struct.dataclass class FlaxAutoencoderKLOutput(BaseOutput): """ Output of AutoencoderKL encoding method. Args: latent_dist (`FlaxDiagonalGaussianDistribution`): Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. 
""" latent_dist: "FlaxDiagonalGaussianDistribution" class FlaxUpsample2D(nn.Module): """ Flax implementation of 2D Upsample layer Args: in_channels (`int`): Input channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, hidden_states): batch, height, width, channels = hidden_states.shape hidden_states = jax.image.resize( hidden_states, shape=(batch, height * 2, width * 2, channels), method="nearest", ) hidden_states = self.conv(hidden_states) return hidden_states class FlaxDownsample2D(nn.Module): """ Flax implementation of 2D Downsample layer Args: in_channels (`int`): Input channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dtype: jnp.dtype = jnp.float32 def setup(self): self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), strides=(2, 2), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states): pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim hidden_states = jnp.pad(hidden_states, pad_width=pad) hidden_states = self.conv(hidden_states) return hidden_states class FlaxResnetBlock2D(nn.Module): """ Flax implementation of 2D Resnet Block. Args: in_channels (`int`): Input channels out_channels (`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for group norm. use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): Whether to use `nin_shortcut`. 
This activates a new layer inside ResNet block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int = None dropout: float = 0.0 groups: int = 32 use_nin_shortcut: bool = None dtype: jnp.dtype = jnp.float32 def setup(self): out_channels = self.in_channels if self.out_channels is None else self.out_channels self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) self.conv1 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) self.dropout_layer = nn.Dropout(self.dropout) self.conv2 = nn.Conv( out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut self.conv_shortcut = None if use_nin_shortcut: self.conv_shortcut = nn.Conv( out_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def __call__(self, hidden_states, deterministic=True): residual = hidden_states hidden_states = self.norm1(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.conv1(hidden_states) hidden_states = self.norm2(hidden_states) hidden_states = nn.swish(hidden_states) hidden_states = self.dropout_layer(hidden_states, deterministic) hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: residual = self.conv_shortcut(residual) return hidden_states + residual class FlaxAttentionBlock(nn.Module): r""" Flax Convolutional based multi-head attention block for diffusion-based VAE. 
Parameters: channels (:obj:`int`): Input channels num_head_channels (:obj:`int`, *optional*, defaults to `None`): Number of attention heads num_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for group norm dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ channels: int num_head_channels: int = None num_groups: int = 32 dtype: jnp.dtype = jnp.float32 def setup(self): self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 dense = partial(nn.Dense, self.channels, dtype=self.dtype) self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) self.query, self.key, self.value = dense(), dense(), dense() self.proj_attn = dense() def transpose_for_scores(self, projection): new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) new_projection = projection.reshape(new_projection_shape) # (B, T, H, D) -> (B, H, T, D) new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) return new_projection def __call__(self, hidden_states): residual = hidden_states batch, height, width, channels = hidden_states.shape hidden_states = self.group_norm(hidden_states) hidden_states = hidden_states.reshape((batch, height * width, channels)) query = self.query(hidden_states) key = self.key(hidden_states) value = self.value(hidden_states) # transpose query = self.transpose_for_scores(query) key = self.transpose_for_scores(key) value = self.transpose_for_scores(value) # compute attentions scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) attn_weights = nn.softmax(attn_weights, axis=-1) # attend to values hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) 
hidden_states = hidden_states.reshape(new_hidden_states_shape) hidden_states = self.proj_attn(hidden_states) hidden_states = hidden_states.reshape((batch, height, width, channels)) hidden_states = hidden_states + residual return hidden_states class FlaxDownEncoderBlock2D(nn.Module): r""" Flax Resnet blocks-based Encoder block for diffusion-based VAE. Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 add_downsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets if self.add_downsample: self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_downsample: hidden_states = self.downsamplers_0(hidden_states) return hidden_states class FlaxUpDecoderBlock2D(nn.Module): r""" Flax Resnet blocks-based Decoder block for diffusion-based VAE. 
Parameters: in_channels (:obj:`int`): Input channels out_channels (:obj:`int`): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int out_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 add_upsample: bool = True dtype: jnp.dtype = jnp.float32 def setup(self): resnets = [] for i in range(self.num_layers): in_channels = self.in_channels if i == 0 else self.out_channels res_block = FlaxResnetBlock2D( in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets if self.add_upsample: self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): for resnet in self.resnets: hidden_states = resnet(hidden_states, deterministic=deterministic) if self.add_upsample: hidden_states = self.upsamplers_0(hidden_states) return hidden_states class FlaxUNetMidBlock2D(nn.Module): r""" Flax Unet Mid-Block module. 
Parameters: in_channels (:obj:`int`): Input channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet and Attention block group norm num_attention_heads (:obj:`int`, *optional*, defaults to `1`): Number of attention heads for each attention block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int dropout: float = 0.0 num_layers: int = 1 resnet_groups: int = 32 num_attention_heads: int = 1 dtype: jnp.dtype = jnp.float32 def setup(self): resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) # there is always at least one resnet resnets = [ FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, groups=resnet_groups, dtype=self.dtype, ) ] attentions = [] for _ in range(self.num_layers): attn_block = FlaxAttentionBlock( channels=self.in_channels, num_head_channels=self.num_attention_heads, num_groups=resnet_groups, dtype=self.dtype, ) attentions.append(attn_block) res_block = FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, groups=resnet_groups, dtype=self.dtype, ) resnets.append(res_block) self.resnets = resnets self.attentions = attentions def __call__(self, hidden_states, deterministic=True): hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn(hidden_states) hidden_states = resnet(hidden_states, deterministic=deterministic) return hidden_states class FlaxEncoder(nn.Module): r""" Flax Implementation of VAE Encoder. This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. 
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (:obj:`int`, *optional*, defaults to 3): Input channels out_channels (:obj:`int`, *optional*, defaults to 3): Output channels down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): DownEncoder block type block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) block_out_channels: Tuple[int] = (64,) layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" double_z: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels # in self.conv_in = nn.Conv( block_out_channels[0], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # downsampling down_blocks = [] output_channel = block_out_channels[0] for i, _ in enumerate(self.down_block_types): input_channel = 
output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = FlaxDownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers=self.layers_per_block, resnet_groups=self.norm_num_groups, add_downsample=not is_final_block, dtype=self.dtype, ) down_blocks.append(down_block) self.down_blocks = down_blocks # middle self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], resnet_groups=self.norm_num_groups, num_attention_heads=None, dtype=self.dtype, ) # end conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( conv_out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, sample, deterministic: bool = True): # in sample = self.conv_in(sample) # downsampling for block in self.down_blocks: sample = block(sample, deterministic=deterministic) # middle sample = self.mid_block(sample, deterministic=deterministic) # end sample = self.conv_norm_out(sample) sample = nn.swish(sample) sample = self.conv_out(sample) return sample class FlaxDecoder(nn.Module): r""" Flax Implementation of VAE Decoder. This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior. 
Finally, this model supports inherent JAX features such as: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (:obj:`int`, *optional*, defaults to 3): Input channels out_channels (:obj:`int`, *optional*, defaults to 3): Output channels up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): UpDecoder block type block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function double_z (:obj:`bool`, *optional*, defaults to `False`): Whether to double the last output channels dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` """ in_channels: int = 3 out_channels: int = 3 up_block_types: Tuple[str] = ("UpDecoderBlock2D",) block_out_channels: int = (64,) layers_per_block: int = 2 norm_num_groups: int = 32 act_fn: str = "silu" dtype: jnp.dtype = jnp.float32 def setup(self): block_out_channels = self.block_out_channels # z to block_in self.conv_in = nn.Conv( block_out_channels[-1], kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) # middle self.mid_block = FlaxUNetMidBlock2D( in_channels=block_out_channels[-1], resnet_groups=self.norm_num_groups, num_attention_heads=None, dtype=self.dtype, ) # upsampling reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] 
up_blocks = [] for i, _ in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 up_block = FlaxUpDecoderBlock2D( in_channels=prev_output_channel, out_channels=output_channel, num_layers=self.layers_per_block + 1, resnet_groups=self.norm_num_groups, add_upsample=not is_final_block, dtype=self.dtype, ) up_blocks.append(up_block) prev_output_channel = output_channel self.up_blocks = up_blocks # end self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), dtype=self.dtype, ) def __call__(self, sample, deterministic: bool = True): # z to block_in sample = self.conv_in(sample) # middle sample = self.mid_block(sample, deterministic=deterministic) # upsampling for block in self.up_blocks: sample = block(sample, deterministic=deterministic) sample = self.conv_norm_out(sample) sample = nn.swish(sample) sample = self.conv_out(sample) return sample class FlaxDiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): # Last axis to account for channels-last self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) self.logvar = jnp.clip(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = jnp.exp(0.5 * self.logvar) self.var = jnp.exp(self.logvar) if self.deterministic: self.var = self.std = jnp.zeros_like(self.mean) def sample(self, key): return self.mean + self.std * jax.random.normal(key, self.mean.shape) def kl(self, other=None): if self.deterministic: return jnp.array([0.0]) if other is None: return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) return 0.5 * jnp.sum( jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, axis=[1, 2, 3], ) def nll(self, sample, axis=[1, 2, 3]): if self.deterministic: return 
jnp.array([0.0]) logtwopi = jnp.log(2.0 * jnp.pi) return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) def mode(self): return self.mean @flax_register_to_config class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): r""" Flax implementation of a VAE model with KL loss for decoding latent representations. This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its general usage and behavior. Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): Tuple of upsample block types. block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`): Tuple of block output channels. layers_per_block (`int`, *optional*, defaults to `2`): Number of ResNet layer for each block. act_fn (`str`, *optional*, defaults to `silu`): The activation function to use. latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. 
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. sample_size (`int`, *optional*, defaults to 32): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): The `dtype` of the parameters. """ in_channels: int = 3 out_channels: int = 3 down_block_types: Tuple[str] = ("DownEncoderBlock2D",) up_block_types: Tuple[str] = ("UpDecoderBlock2D",) block_out_channels: Tuple[int] = (64,) layers_per_block: int = 1 act_fn: str = "silu" latent_channels: int = 4 norm_num_groups: int = 32 sample_size: int = 32 scaling_factor: float = 0.18215 dtype: jnp.dtype = jnp.float32 def setup(self): self.encoder = FlaxEncoder( in_channels=self.config.in_channels, out_channels=self.config.latent_channels, down_block_types=self.config.down_block_types, block_out_channels=self.config.block_out_channels, layers_per_block=self.config.layers_per_block, act_fn=self.config.act_fn, norm_num_groups=self.config.norm_num_groups, double_z=True, dtype=self.dtype, ) self.decoder = FlaxDecoder( in_channels=self.config.latent_channels, out_channels=self.config.out_channels, up_block_types=self.config.up_block_types, block_out_channels=self.config.block_out_channels, layers_per_block=self.config.layers_per_block, norm_num_groups=self.config.norm_num_groups, act_fn=self.config.act_fn, dtype=self.dtype, 
) self.quant_conv = nn.Conv( 2 * self.config.latent_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) self.post_quant_conv = nn.Conv( self.config.latent_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype, ) def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} return self.init(rngs, sample)["params"] def encode(self, sample, deterministic: bool = True, return_dict: bool = True): sample = jnp.transpose(sample, (0, 2, 3, 1)) hidden_states = self.encoder(sample, deterministic=deterministic) moments = self.quant_conv(hidden_states) posterior = FlaxDiagonalGaussianDistribution(moments) if not return_dict: return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) def decode(self, latents, deterministic: bool = True, return_dict: bool = True): if latents.shape[-1] != self.config.latent_channels: latents = jnp.transpose(latents, (0, 2, 3, 1)) hidden_states = self.post_quant_conv(latents) hidden_states = self.decoder(hidden_states, deterministic=deterministic) hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) if not return_dict: return (hidden_states,) return FlaxDecoderOutput(sample=hidden_states) def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) if sample_posterior: rng = self.make_rng("gaussian") hidden_states = posterior.latent_dist.sample(rng) else: hidden_states = posterior.latent_dist.mode() sample = self.decode(hidden_states, return_dict=return_dict).sample if not return_dict: return (sample,) return FlaxDecoderOutput(sample=sample) 
================================================ FILE: 6DoF/diffusers/models/vq_model.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, apply_forward_hook from .modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer @dataclass class VQEncoderOutput(BaseOutput): """ Output of VQModel encoding method. Args: latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The encoded output sample from the last layer of the model. """ latents: torch.FloatTensor class VQModel(ModelMixin, ConfigMixin): r""" A VQ-VAE model for decoding latent representations. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types. 
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. scaling_factor (`float`, *optional*, defaults to `0.18215`): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. 
""" @register_to_config def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 3, sample_size: int = 32, num_vq_embeddings: int = 256, norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, norm_type: str = "group", # group, spatial ): super().__init__() # pass init params to Encoder self.encoder = Encoder( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=False, ) vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1) self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1) # pass init params to Decoder self.decoder = Decoder( in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, norm_type=norm_type, ) @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: h = self.encoder(x) h = self.quant_conv(h) if not return_dict: return (h,) return VQEncoderOutput(latents=h) @apply_forward_hook def decode( self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: quant, emb_loss, info = self.quantize(h) else: quant = h quant2 = self.post_quant_conv(quant) dec = self.decoder(quant2, quant if 
self.config.norm_type == "spatial" else None) if not return_dict: return (dec,) return DecoderOutput(sample=dec) def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r""" The [`VQModel`] forward method. Args: sample (`torch.FloatTensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. Returns: [`~models.vq_model.VQEncoderOutput`] or `tuple`: If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` is returned. """ x = sample h = self.encode(x).latents dec = self.decode(h).sample if not return_dict: return (dec,) return DecoderOutput(sample=dec) ================================================ FILE: 6DoF/diffusers/optimization.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""PyTorch optimization for diffusion models.""" import math from enum import Enum from typing import Optional, Union from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR from .utils import logging logger = logging.get_logger(__name__) class SchedulerType(Enum): LINEAR = "linear" COSINE = "cosine" COSINE_WITH_RESTARTS = "cosine_with_restarts" POLYNOMIAL = "polynomial" CONSTANT = "constant" CONSTANT_WITH_WARMUP = "constant_with_warmup" PIECEWISE_CONSTANT = "piecewise_constant" def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): """ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1.0, num_warmup_steps)) return 1.0 return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. step_rules (`string`): The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps and multiple 0.005 for the other steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ rules_dict = {} rule_list = step_rules.split(",") for rule_str in rule_list[:-1]: value_str, steps_str = rule_str.split(":") steps = int(steps_str) value = float(value_str) rules_dict[steps] = value last_lr_multiple = float(rule_list[-1]) def create_rules_function(rules_dict, last_lr_multiple): def rule_func(steps: int) -> float: sorted_steps = sorted(rules_dict.keys()) for i, sorted_step in enumerate(sorted_steps): if steps < sorted_step: return rules_dict[sorted_steps[i]] return last_lr_multiple return rule_func rules_func = create_rules_function(rules_dict, last_lr_multiple) return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 
Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) ) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_cosine_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_periods (`float`, *optional*, defaults to 0.5): The number of periods of the cosine function in a schedule (the default is to just decrease from the max value to 0 following a half-cosine). last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_cosine_with_hard_restarts_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_cycles (`int`, *optional*, defaults to 1): The number of hard restarts to use. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) if progress >= 1.0: return 0.0 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_polynomial_decay_schedule_with_warmup( optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 ): """ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. lr_end (`float`, *optional*, defaults to 1e-7): The end LR. power (`float`, *optional*, defaults to 1.0): Power factor. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" lr_init = optimizer.defaults["lr"] if not (lr_init > lr_end): raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) elif current_step > num_training_steps: return lr_end / lr_init # as LambdaLR multiplies by lr_init else: lr_range = lr_init - lr_end decay_steps = num_training_steps - num_warmup_steps pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps decay = lr_range * pct_remaining**power + lr_end return decay / lr_init # as LambdaLR multiplies by lr_init return LambdaLR(optimizer, lr_lambda, last_epoch) TYPE_TO_SCHEDULER_FUNCTION = { SchedulerType.LINEAR: get_linear_schedule_with_warmup, SchedulerType.COSINE: get_cosine_schedule_with_warmup, SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, SchedulerType.CONSTANT: get_constant_schedule, SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, } def get_scheduler( name: Union[str, SchedulerType], optimizer: Optimizer, step_rules: Optional[str] = None, num_warmup_steps: Optional[int] = None, num_training_steps: Optional[int] = None, num_cycles: int = 1, power: float = 1.0, last_epoch: int = -1, ): """ Unified API to get any scheduler from its name. Args: name (`str` or `SchedulerType`): The name of the scheduler to use. optimizer (`torch.optim.Optimizer`): The optimizer that will be used during training. step_rules (`str`, *optional*): A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. num_warmup_steps (`int`, *optional*): The number of warmup steps to do. 
This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_training_steps (`int``, *optional*): The number of training steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_cycles (`int`, *optional*): The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. power (`float`, *optional*, defaults to 1.0): Power factor. See `POLYNOMIAL` scheduler last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. """ name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return schedule_func(optimizer, last_epoch=last_epoch) if name == SchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") if name == SchedulerType.COSINE_WITH_RESTARTS: return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, last_epoch=last_epoch, ) if name == SchedulerType.POLYNOMIAL: return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, last_epoch=last_epoch, ) return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch ) 
================================================ FILE: 6DoF/diffusers/pipeline_utils.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE: This file is deprecated and will be removed in a future version. # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 from .utils import deprecate deprecate( "pipelines_utils", "0.22.0", "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. 
Please import from diffusers.pipelines.pipeline_utils instead.", standard_warn=False, stacklevel=3, ) ================================================ FILE: 6DoF/diffusers/pipelines/__init__.py ================================================ from ..utils import ( OptionalDependencyNotAvailable, is_flax_available, is_invisible_watermark_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, is_onnx_available, is_torch_available, is_transformers_available, ) try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: from .consistency_models import ConsistencyModelPipeline from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput from .pndm import PNDMPipeline from .repaint import RePaintPipeline from .score_sde_ve import ScoreSdeVePipeline from .stochastic_karras_ve import KarrasVePipeline try: if not (is_torch_available() and is_librosa_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_librosa_objects import * # noqa F403 else: from .audio_diffusion import AudioDiffusionPipeline, Mel try: if not (is_torch_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .audioldm import AudioLDMPipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, ) from .deepfloyd_if import ( 
IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, IFInpaintingPipeline, IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, ) from .kandinsky import ( KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline, KandinskyPriorPipeline, ) from .kandinsky2_2 import ( KandinskyV22ControlnetImg2ImgPipeline, KandinskyV22ControlnetPipeline, KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, ) from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionDepth2ImgPipeline, StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionInstructPix2PixPipeline, StableDiffusionLatentUpscalePipeline, StableDiffusionLDM3DPipeline, StableDiffusionModelEditingPipeline, StableDiffusionPanoramaPipeline, StableDiffusionParadigmsPipeline, StableDiffusionPipeline, StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import 
VQDiffusionPipeline try: if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_onnx_objects import * # noqa F403 else: from .onnx_utils import OnnxRuntimeModel try: if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 else: from .stable_diffusion import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline, StableDiffusionOnnxPipeline, ) try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: from .stable_diffusion import StableDiffusionKDiffusionPipeline try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_objects import * # noqa F403 else: from .pipeline_flax_utils import FlaxDiffusionPipeline try: if not (is_flax_available() and is_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: from .controlnet import FlaxStableDiffusionControlNetPipeline from .stable_diffusion import ( 
FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, ) try: if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 else: from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline ================================================ FILE: 6DoF/diffusers/pipelines/alt_diffusion/__init__.py ================================================ from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL from PIL import Image from ...utils import BaseOutput, is_torch_available, is_transformers_available @dataclass # Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt class AltDiffusionPipelineOutput(BaseOutput): """ Output class for Alt Diffusion pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, or `None` if safety checking could not be performed. 
""" images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] if is_transformers_available() and is_torch_available(): from .modeling_roberta_series import RobertaSeriesModelWithTransformation from .pipeline_alt_diffusion import AltDiffusionPipeline from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline ================================================ FILE: 6DoF/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py ================================================ from dataclasses import dataclass from typing import Optional, Tuple import torch from torch import nn from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel from transformers.utils import ModelOutput @dataclass class TransformationModelOutput(ModelOutput): """ Base class for text model's outputs that also contains a pooling of the last hidden states. Args: text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The text embeddings obtained by applying the projection layer to the pooler_output. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
    """
    XLM-RoBERTa encoder followed by a linear projection to `config.project_dim`.

    When `config.has_pre_transformation` is truthy, the projection is computed from the layer-normalised
    second-to-last hidden state (`pre_LN` + `transformation_pre`); otherwise it is computed from the last hidden
    state with `transformation`.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    base_model_prefix = "roberta"
    config_class = RobertaSeriesConfig

    def __init__(self, config):
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        # `transformation` is created unconditionally, even when the pre-transformation path is the one used.
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
        if self.has_pre_transformation:
            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r"""
        Run the base encoder and project its hidden states.

        All tensor arguments, `output_attentions` and `output_hidden_states` are forwarded unchanged to the base
        model, except that hidden states are always requested when `has_pre_transformation` is set (they are
        needed to read the second-to-last layer).

        Args:
            return_dict (`bool`, *optional*):
                Whether to return a [`TransformationModelOutput`] (default: `config.use_return_dict`) or a plain
                tuple of its non-`None` fields in declaration order (`projection_state` first).

        Returns:
            [`TransformationModelOutput`] or `tuple`: the projected states together with the base model's
            `last_hidden_state`, `hidden_states` and `attentions`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the base model for a dict-style output: the code below reads fields by name, which would
        # fail on the tuple returned for `return_dict=False`. The caller's preference is applied at the end.
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
            return_dict=True,
        )

        if self.has_pre_transformation:
            # Project the layer-normalised second-to-last hidden state.
            penultimate_state = self.pre_LN(outputs.hidden_states[-2])
            projection_state = self.transformation_pre(penultimate_state)
        else:
            projection_state = self.transformation(outputs.last_hidden_state)

        result = TransformationModelOutput(
            projection_state=projection_state,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return result if return_dict else result.to_tuple()
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from packaging import version from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from . 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale the guided noise prediction `noise_cfg` towards the per-sample standard deviation of
    `noise_pred_text`, then blend the rescaled and original predictions with weight `guidance_rescale`.

    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.

    Args:
        noise_cfg: noise prediction after classifier-free guidance.
        noise_pred_text: text-conditioned noise prediction.
        guidance_rescale: blend weight; `0.0` returns values equal to `noise_cfg`.

    Returns:
        The blended noise prediction, same shape as `noise_cfg`.
    """

    def _std_per_sample(tensor):
        # Standard deviation over every dimension except the leading (batch) one, kept broadcastable.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    text_std = _std_per_sample(noise_pred_text)
    cfg_std = _std_per_sample(noise_cfg)

    # Rescaling the guided result fixes overexposure ...
    rescaled = noise_cfg * (text_std / cfg_std)
    # ... and mixing it back with the original guided result avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`RobertaSeriesModelWithTransformation`]): Frozen text-encoder. Alt Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`XLMRobertaTokenizer`): Tokenizer of class [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: RobertaSeriesModelWithTransformation,
    tokenizer: XLMRobertaTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and patch outdated scheduler / unet configurations.

    Outdated configs (`steps_offset != 1`, `clip_sample is True`, old unet with `sample_size < 64`) trigger a
    deprecation notice and are overwritten in place on the passed objects.

    Raises:
        ValueError: if a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Only unet configs saved by diffusers < 0.9.0 are subject to the `sample_size` correction below.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # The VAE scale factor is 2 ** (number of VAE blocks - 1), taken from the VAE config.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload `unet`, `text_encoder`, `vae` and (if present) the safety checker to CPU with accelerate's
    `cpu_offload`, using `cuda:{gpu_id}` as the execution device.

    Offloading happens on a submodule basis, so memory savings are higher than with `enable_model_cpu_offload`,
    but performance is lower.

    Raises:
        ImportError: if accelerate is missing or older than v0.14.0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    pipeline_on_accelerator = self.device.type != "cpu"
    if pipeline_on_accelerator:
        self.to("cpu", silence_dtype_warnings=True)
        # otherwise we don't see the memory savings (but they probably exist)
        torch.cuda.empty_cache()

    for component in (self.unet, self.text_encoder, self.vae):
        cpu_offload(component, execution_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=execution_device, offload_buffers=True)
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        `torch.FloatTensor`: the prompt embeddings repeated `num_images_per_prompt` times; when
        `do_classifier_free_guidance` is set, the negative embeddings are concatenated in front of them along
        the first dimension.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    # Batch size comes from the prompt when given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation, only to detect (and report) text that was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # The attention mask is forwarded only when the text encoder config asks for it.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        # First element of the encoder output; presumably `projection_state` for
        # `RobertaSeriesModelWithTransformation` -- TODO confirm for other encoders.
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # NOTE(review): yields a single token list regardless of `batch_size`; looks like this assumes
            # `batch_size == 1` here -- confirm for the `prompt_embeds`-only call path.
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad/truncate the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Checks, in order: image size divisible by 8, `callback_steps` a positive int, exactly one of
    `prompt` / `prompt_embeds` given (and `prompt` a `str` or `list`), at most one of `negative_prompt` /
    `negative_prompt_embeds` given, and matching shapes when both embedding tensors are passed.

    Raises:
        ValueError: on the first violated condition.
    """
    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 
guidance_rescale (`float`, *optional*, defaults to 0.7): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: Returns: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from packaging import version from transformers import CLIPImageProcessor, XLMRobertaTokenizer from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import AltDiffusionImg2ImgPipeline >>> device = "cuda" >>> model_id_or_path = "BAAI/AltDiffusion-m9" >>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) >>> pipe = pipe.to(device) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) >>> # "A fantasy landscape, trending on artstation" >>> prompt = "幻想风景, artstation" >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images >>> images[0].save("幻想风景.png") ``` """ # Copied from 
def preprocess(image):
    """Turn PIL image(s) or tensor(s) into a single batched image tensor (deprecated helper).

    A `FutureWarning` pointing to `VaeImageProcessor.preprocess` is emitted on every call.

    Args:
        image: A `torch.Tensor`, a `PIL.Image.Image`, or a sequence whose first element is one of
            those two types.

    Returns:
        - the input itself when it already is a tensor;
        - for PIL input, a float32 tensor with values rescaled from [0, 255] to [-1, 1] and the axes
          permuted with `(0, 3, 1, 2)`; every image is resized to the first image's size rounded down
          to a multiple of 8 (assumes the images carry a trailing channel axis, e.g. RGB);
        - for a sequence of tensors, their concatenation along dim 0;
        - any other sequence, unchanged.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    # Tensors are assumed to be ready for the pipeline and are passed through untouched.
    if isinstance(image, torch.Tensor):
        return image

    # Promote a lone PIL image to a one-element batch so both cases share the code below.
    batch = [image] if isinstance(image, PIL.Image.Image) else image
    first = batch[0]

    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # Round both sides down to an integer multiple of 8.
        width -= width % 8
        height -= height % 8
        frames = [
            np.array(picture.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for picture in batch
        ]
        pixels = np.concatenate(frames, axis=0)
        pixels = np.array(pixels).astype(np.float32) / 255.0
        pixels = pixels.transpose(0, 3, 1, 2)
        pixels = 2.0 * pixels - 1.0
        return torch.from_numpy(pixels)

    if isinstance(first, torch.Tensor):
        return torch.cat(batch, dim=0)

    return batch
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`RobertaSeriesModelWithTransformation`]): Frozen text-encoder. Alt Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`XLMRobertaTokenizer`): Tokenizer of class [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: RobertaSeriesModelWithTransformation, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. 
If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
self.final_offload_hook = hook @property def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker on `image`.

    Args:
        image: Tensor or non-tensor (handed to `image_processor.numpy_to_pil`) image batch.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted `pixel_values` are cast to before being given to the checker.

    Returns:
        A tuple `(image, has_nsfw_concept)`. When no safety checker is configured, `image` is
        returned untouched together with `None`; otherwise both values come from the checker.
    """
    # Nothing to do when the checker is disabled.
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images; pick the conversion matching the input type.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_features.pixel_values.to(dtype)
    checked_image, nsfw_flags = self.safety_checker(images=image, clip_input=clip_input)
    return checked_image, nsfw_flags
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """Validate the user-facing arguments of the pipeline call.

    Args:
        prompt: Text prompt(s); must be a `str` or `list` when given.
        strength: Must not be below 0 or above 1.
        callback_steps: Must be a positive `int`.
        negative_prompt: Negative prompt(s); mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings; mutually exclusive with `prompt`.
        negative_prompt_embeds: Pre-computed negative embeddings; must have the same shape as
            `prompt_embeds` when both are given.

    Raises:
        ValueError: On the first violated rule, checked in the order listed above.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    # Exactly one of `prompt` / `prompt_embeds` has to be supplied.
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
) elif isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. 
The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 
================================================ FILE: 6DoF/diffusers/pipelines/audio_diffusion/__init__.py ================================================ from .mel import Mel from .pipeline_audio_diffusion import AudioDiffusionPipeline ================================================ FILE: 6DoF/diffusers/pipelines/audio_diffusion/mel.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np # noqa: E402 from ...configuration_utils import ConfigMixin, register_to_config from ...schedulers.scheduling_utils import SchedulerMixin try: import librosa # noqa: E402 _librosa_can_be_imported = True _import_error = "" except Exception as e: _librosa_can_be_imported = False _import_error = ( f"Cannot import librosa because {e}. Make sure to correctly install librosa to be able to install it." 
class Mel(ConfigMixin, SchedulerMixin):
    """
    Convert between raw audio and grayscale mel-spectrogram images using librosa.

    Parameters:
        x_res (`int`): x resolution of spectrogram (time)
        y_res (`int`): y resolution of spectrogram (frequency bins)
        sample_rate (`int`): sample rate of audio
        n_fft (`int`): value passed to librosa as `n_fft`
        hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
        top_db (`int`): decibel range passed to `librosa.power_to_db` and used to scale pixel values
        n_iter (`int`): number of iterations for Griffin-Lim mel inversion
    """

    config_name = "mel_config.json"

    @register_to_config
    def __init__(
        self,
        x_res: int = 256,
        y_res: int = 256,
        sample_rate: int = 22050,
        n_fft: int = 2048,
        hop_length: int = 512,
        top_db: int = 80,
        n_iter: int = 32,
    ):
        # Fail fast: every conversion in this class goes through librosa.
        if not _librosa_can_be_imported:
            raise ValueError(_import_error)

        self.hop_length = hop_length
        self.sr = sample_rate
        self.n_fft = n_fft
        self.top_db = top_db
        self.n_iter = n_iter
        self.set_resolution(x_res, y_res)
        self.audio = None

    def set_resolution(self, x_res: int, y_res: int):
        """Set resolution.

        Args:
            x_res (`int`): x resolution of spectrogram (time)
            y_res (`int`): y resolution of spectrogram (frequency bins)
        """
        self.x_res = x_res
        self.y_res = y_res
        # One mel band per pixel row.
        self.n_mels = self.y_res
        # Number of audio samples consumed by one spectrogram image.
        self.slice_size = self.x_res * self.hop_length - 1

    def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
        """Load audio.

        `audio_file` takes precedence when both arguments are given.

        Args:
            audio_file (`str`): must be a file on disk due to Librosa limitation or
            raw_audio (`np.ndarray`): audio as numpy array

        Raises:
            ValueError: if neither `audio_file` nor `raw_audio` is provided.
        """
        if audio_file is not None:
            self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
        elif raw_audio is not None:
            self.audio = raw_audio
        else:
            # Previously this fell through to `len(None)` and failed with an opaque TypeError.
            raise ValueError("Either `audio_file` or `raw_audio` must be provided.")

        # Pad with silence if necessary.
        min_length = self.x_res * self.hop_length
        if len(self.audio) < min_length:
            self.audio = np.concatenate([self.audio, np.zeros((min_length - len(self.audio),))])

    def get_number_of_slices(self) -> int:
        """Get number of slices in audio.

        Returns:
            `int`: number of spectrograms audio can be sliced into
        """
        return len(self.audio) // self.slice_size

    def get_audio_slice(self, slice: int = 0) -> np.ndarray:
        """Get slice of audio.

        Args:
            slice (`int`): slice number of audio (out of get_number_of_slices())

        Returns:
            `np.ndarray`: audio as numpy array
        """
        return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]

    def get_sample_rate(self) -> int:
        """Get sample rate.

        Returns:
            `int`: sample rate of audio
        """
        return self.sr

    def audio_slice_to_image(self, slice: int) -> Image.Image:
        """Convert slice of audio to spectrogram.

        Args:
            slice (`int`): slice number of audio to convert (out of get_number_of_slices())

        Returns:
            `PIL Image`: grayscale image of x_res x y_res
        """
        S = librosa.feature.melspectrogram(
            y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
        )
        log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
        # Shift by top_db and rescale so that a span of top_db decibels maps onto 0..255;
        # the +0.5 rounds to nearest before the uint8 cast truncates.
        bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
        image = Image.fromarray(bytedata)
        return image

    def image_to_audio(self, image: Image.Image) -> np.ndarray:
        """Converts spectrogram to audio.

        Args:
            image (`PIL Image`): x_res x y_res grayscale image

        Returns:
            audio (`np.ndarray`): raw audio
        """
        # Interprets the image bytes as one uint8 per pixel (single-channel image).
        bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
        # Inverse of the pixel scaling applied in `audio_slice_to_image`.
        log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
        S = librosa.db_to_power(log_S)
        audio = librosa.feature.inverse.mel_to_audio(
            S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
        )
        return audio
class AudioDiffusionPipeline(DiffusionPipeline):
    """
    Pipeline that generates audio by denoising mel-spectrogram images.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
        unet ([`UNet2DConditionModel`]): UNET model
        mel ([`Mel`]): transform audio <-> spectrogram
        scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
    """

    _optional_components = ["vqvae"]

    def __init__(
        self,
        vqvae: AutoencoderKL,
        unet: UNet2DConditionModel,
        mel: Mel,
        scheduler: Union[DDIMScheduler, DDPMScheduler],
    ):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)

    def get_default_steps(self) -> int:
        """Returns default number of steps recommended for inference

        Returns:
            `int`: number of steps (50 for a DDIM scheduler, 1000 otherwise)
        """
        return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        audio_file: str = None,
        raw_audio: np.ndarray = None,
        slice: int = 0,
        start_step: int = 0,
        steps: int = None,
        generator: torch.Generator = None,
        mask_start_secs: float = 0,
        mask_end_secs: float = 0,
        step_generator: torch.Generator = None,
        eta: float = 0,
        noise: torch.Tensor = None,
        encoding: torch.Tensor = None,
        return_dict=True,
    ) -> Union[
        Union[AudioPipelineOutput, ImagePipelineOutput],
        Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
    ]:
        """Generate random mel spectrogram from audio input and convert to audio.

        Args:
            batch_size (`int`): number of samples to generate
            audio_file (`str`): must be a file on disk due to Librosa limitation or
            raw_audio (`np.ndarray`): audio as numpy array
            slice (`int`): slice number of audio to convert
            start_step (int): step to start from
            steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
            generator (`torch.Generator`): random number generator or None
            mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
            mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
            step_generator (`torch.Generator`): random number generator used to de-noise or None
            eta (`float`): parameter between 0 and 1 used with DDIM scheduler
            noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
            encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
            return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple

        Returns:
            `List[PIL Image]`: mel spectrograms (`int`, `List[np.ndarray]`): sample rate and raw audios
        """

        steps = steps or self.get_default_steps()
        self.scheduler.set_timesteps(steps)
        step_generator = step_generator or generator
        # For backwards compatibility: a scalar sample size is expanded to (height, width).
        if isinstance(self.unet.config.sample_size, int):
            self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size)
        if noise is None:
            noise = randn_tensor(
                (
                    batch_size,
                    self.unet.config.in_channels,
                    self.unet.config.sample_size[0],
                    self.unet.config.sample_size[1],
                ),
                generator=generator,
                device=self.device,
            )
        images = noise
        mask = None

        if audio_file is not None or raw_audio is not None:
            self.mel.load_audio(audio_file, raw_audio)
            input_image = self.mel.audio_slice_to_image(slice)
            input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
                (input_image.height, input_image.width)
            )
            # Rescale pixel values from [0, 255] to [-1, 1].
            input_image = (input_image / 255) * 2 - 1
            input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)

            if self.vqvae is not None:
                input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
                    generator=generator
                )[0]
                input_images = self.vqvae.config.scaling_factor * input_images

            if start_step > 0:
                # Start from the input noised to the level of the step preceding `start_step`.
                images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])

            pixels_per_second = (
                self.unet.config.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
            )
            mask_start = int(mask_start_secs * pixels_per_second)
            mask_end = int(mask_end_secs * pixels_per_second)
            # Noised copies of the input for the remaining timesteps; indexed by `step` in the loop below.
            mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))

        for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
            if isinstance(self.unet, UNet2DConditionModel):
                model_output = self.unet(images, t, encoding)["sample"]
            else:
                model_output = self.unet(images, t)["sample"]

            if isinstance(self.scheduler, DDIMScheduler):
                images = self.scheduler.step(
                    model_output=model_output,
                    timestep=t,
                    sample=images,
                    eta=eta,
                    generator=step_generator,
                )["prev_sample"]
            else:
                images = self.scheduler.step(
                    model_output=model_output,
                    timestep=t,
                    sample=images,
                    generator=step_generator,
                )["prev_sample"]

            if mask is not None:
                # Overwrite the masked columns with the noised input so they are not generated.
                if mask_start > 0:
                    images[:, :, :, :mask_start] = mask[:, step, :, :mask_start]
                if mask_end > 0:
                    images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]

        if self.vqvae is not None:
            # 0.18215 was scaling factor used in training to ensure unit variance
            images = 1 / self.vqvae.config.scaling_factor * images
            images = self.vqvae.decode(images)["sample"]

        # Map from [-1, 1] to uint8 pixel values in channel-last layout.
        images = (images / 2 + 0.5).clamp(0, 1)
        images = images.cpu().permute(0, 2, 3, 1).numpy()
        images = (images * 255).round().astype("uint8")
        images = list(
            (Image.fromarray(_[:, :, 0]) for _ in images)
            if images.shape[3] == 1
            else (Image.fromarray(_, mode="RGB").convert("L") for _ in images)
        )

        audios = [self.mel.image_to_audio(_) for _ in images]
        if not return_dict:
            return images, (self.mel.get_sample_rate(), audios)

        return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))

    @torch.no_grad()
    def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
        """Reverse step process: recover noisy image from generated image.

        Args:
            images (`List[PIL Image]`): list of images to encode
            steps (`int`): number of encoding steps to perform (defaults to 50)

        Returns:
            `np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
        """

        # Only works with DDIM as this method is deterministic
        assert isinstance(self.scheduler, DDIMScheduler)
        self.scheduler.set_timesteps(steps)
        sample = np.array(
            [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images]
        )
        sample = (sample / 255) * 2 - 1
        sample = torch.Tensor(sample).to(self.device)

        # Walk the scheduler timesteps in reverse order.
        for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
            prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else self.scheduler.final_alpha_cumprod
            )
            beta_prod_t = 1 - alpha_prod_t
            model_output = self.unet(sample, t)["sample"]
            pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
            sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5)
            sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output

        return sample

    @staticmethod
    def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
        """Spherical Linear intERPolation

        Falls back to plain linear interpolation when the two tensors are (anti)parallel, where the spherical
        formula would divide by (almost) zero.

        Args:
            x0 (`torch.Tensor`): first tensor to interpolate between
            x1 (`torch.Tensor`): second tensor to interpolate between
            alpha (`float`): interpolation between 0 and 1

        Returns:
            `torch.Tensor`: interpolated tensor
        """

        cos_theta = float(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
        # Rounding can push the cosine marginally outside [-1, 1], which makes `acos` raise a domain error.
        cos_theta = max(-1.0, min(1.0, cos_theta))
        theta = acos(cos_theta)
        sin_theta = sin(theta)
        if abs(sin_theta) < 1e-6:
            return (1 - alpha) * x0 + alpha * x1
        return sin((1 - alpha) * theta) * x0 / sin_theta + sin(alpha * theta) * x1 / sin_theta
import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch import torch.nn.functional as F from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import AudioLDMPipeline >>> pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16) >>> pipe = pipe.to("cuda") >>> prompt = "A hammer hitting a wooden surface" >>> audio = pipe(prompt).audio[0] ``` """ class AudioLDMPipeline(DiffusionPipeline): r""" Pipeline for text-to-audio generation using AudioLDM. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations. text_encoder ([`ClapTextModelWithProjection`]): Frozen text-encoder. AudioLDM uses the text portion of [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection), specifically the [RoBERTa HSTAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. tokenizer ([`PreTrainedTokenizer`]): Tokenizer of class [RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer). unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents. 
scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. vocoder ([`SpeechT5HifiGan`]): Vocoder of class [SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan). """ def __init__( self, vae: AutoencoderKL, text_encoder: ClapTextModelWithProjection, tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast], unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, vocoder: SpeechT5HifiGan, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, vocoder=vocoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and vocoder have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.vocoder]: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_waveforms_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device (`torch.device`): torch device num_waveforms_per_prompt (`int`): number of waveforms that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the audio generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLAP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask.to(device), ) prompt_embeds = prompt_embeds.text_embeds # additional L_2 normalization over each hidden-state prompt_embeds = F.normalize(prompt_embeds, dim=-1) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) ( bs_embed, seq_len, ) = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt) prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds 
is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids.to(device) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input_ids, attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds.text_embeds # additional L_2 normalization over each hidden-state negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len) # For classifier free guidance, we need to do two forward passes. 
    def decode_latents(self, latents):
        """
        Decode latents into a mel spectrogram with the VAE.

        The latents are multiplied by `1 / self.vae.config.scaling_factor`, passed to `self.vae.decode`, and the
        `.sample` attribute of the decoder output is returned.
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        mel_spectrogram = self.vae.decode(latents).sample
        return mel_spectrogram

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        """
        Run `self.vocoder` on a mel spectrogram and return the waveform as a float32 tensor on the CPU.

        A 4-dimensional input has its dimension 1 squeezed away before it is handed to the vocoder.
        """
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        """
        Build the keyword arguments forwarded to `self.scheduler.step`.

        `eta` and `generator` are only included when the signature of `self.scheduler.step` has a parameter of that
        name, so schedulers that do not accept them are never passed an unexpected keyword.
        """
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        """
        Validate the arguments of `__call__` and raise `ValueError` on the first violation.

        Checks performed, in order:
            - `audio_length_in_s` is at least `vocoder_upsample_factor * self.vae_scale_factor`;
            - `self.vocoder.config.model_in_dim` is divisible by `self.vae_scale_factor`;
            - `callback_steps` is a positive `int`;
            - exactly one of `prompt` / `prompt_embeds` is given, and `prompt` is a `str` or a `list`;
            - `negative_prompt` and `negative_prompt_embeds` are not both given;
            - `prompt_embeds` and `negative_prompt_embeds` have the same shape when both are given.
        """
        # shortest audio the model can produce: one latent row after VAE down-scaling
        min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
        if audio_length_in_s < min_audio_length_in_s:
            raise ValueError(
                f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
                f"is {audio_length_in_s}."
            )

        if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
            raise ValueError(
                f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
                f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
                f"{self.vae_scale_factor}."
            )

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
        """
        Return the initial latents, scaled by `self.scheduler.init_noise_sigma`.

        When `latents` is `None`, noise of shape `(batch_size, num_channels_latents, height // vae_scale_factor,
        model_in_dim // vae_scale_factor)` is drawn with `randn_tensor`; otherwise the given tensor is moved to
        `device` (its dtype is left untouched).

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            self.vocoder.config.model_in_dim // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 2.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`
                instead.
            audio_length_in_s (`float`, *optional*):
                The length of the generated audio sample in seconds. If `None`, it is computed as
                `unet.config.sample_size * vae_scale_factor * vocoder_upsample_factor` (the original documentation
                quotes 5.12 seconds for this default).
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 2.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate audios that are closely linked to the text `prompt`,
                usually at the expense of lower sound quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is not greater than `1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an `AudioPipelineOutput` instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. Must be a positive integer.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. With `"np"` the waveform is converted with `.numpy()`; any
                other value (e.g. `"pt"`) leaves it as a `torch.Tensor`.

        Examples:

        Returns:
            `AudioPipelineOutput` or `tuple`:
            `AudioPipelineOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first
            element holds the generated audios.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        # remembered so the padded waveform can be trimmed back to the requested length in step 8
        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        if height % self.vae_scale_factor != 0:
            # round the spectrogram height up to the next multiple of the VAE scale factor
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                # the text embedding is supplied through `class_labels`; `encoder_hidden_states` is left as None
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=None,
                    class_labels=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        mel_spectrogram = self.decode_latents(latents)

        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)

        # drop the samples that only exist because the height was rounded up in step 0
        audio = audio[:, :original_waveform_length]

        if output_type == "np":
            audio = audio.numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)
class ConsistencyModelPipeline(DiffusionPipeline):
    r"""
    Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1].

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    [1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
    https://arxiv.org/pdf/2303.01469

    Args:
        unet ([`UNet2DModel`]):
            Unconditional or class-conditional U-Net architecture to denoise image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible
            with [`CMStochasticIterativeScheduler`].
    """

    def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
        )

        # This pipeline has no safety checker; the attribute exists so the offload helpers can test for it.
        self.safety_checker = None

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads the `unet` (and the safety checker, if one is set) to CPU using accelerate's `cpu_offload`,
        significantly reducing memory usage. Offloading happens on a submodule basis. Memory savings are higher than
        with `enable_model_cpu_offload`, but performance is lower.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the models are executed on.

        Raises:
            ImportError: If `accelerate` is missing or older than `0.14.0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads the `unet` (and the safety checker, if one is set) to CPU using accelerate's
        `cpu_offload_with_hook`, reducing memory usage with a low impact on performance. Compared to
        `enable_sequential_cpu_offload`, whole models rather than submodules are moved. The hook of the last model is
        stored in `self.final_offload_hook` so that `__call__` can offload it manually once generation is done.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the models are executed on.

        Raises:
            ImportError: If `accelerate` is missing or older than `0.17.0.dev0`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.unet]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        """
        Return the initial sample of shape `(batch_size, num_channels, height, width)`, scaled by
        `self.scheduler.init_noise_sigma`.

        Noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the given tensor is moved to `device`
        and cast to `dtype`.

        Raises:
            ValueError: If `generator` is a list whose length differs from `batch_size`.
        """
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"):
        """
        Convert a model output in `[-1, 1]` to the requested output format.

        Args:
            sample (`torch.FloatTensor`): Batch of images; permuted from channels-first to channels-last for
                `"np"` and `"pil"`.
            output_type (`str`): One of `"pt"` (clamped tensor), `"np"` (NumPy array) or `"pil"` (PIL images).

        Raises:
            ValueError: If `output_type` is not one of the three supported values.
        """
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
            )

        # Equivalent to diffusers.VaeImageProcessor.denormalize
        sample = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return sample

        # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "np":
            return sample

        # Output_type must be 'pil'
        sample = self.numpy_to_pil(sample)
        return sample

    def prepare_class_labels(self, batch_size, device, class_labels=None):
        """
        Normalise `class_labels` to a tensor on `device`, or return `None` for a model without class embeddings.

        For a class-conditional model (`unet.config.num_class_embeds` is set) a list or a single int is converted to
        an int tensor, and `None` is replaced by `batch_size` labels drawn uniformly at random. Any other value
        (e.g. an existing tensor) is only moved to `device`.
        """
        if self.unet.config.num_class_embeds is not None:
            if isinstance(class_labels, list):
                class_labels = torch.tensor(class_labels, dtype=torch.int)
            elif isinstance(class_labels, int):
                assert batch_size == 1, "Batch size must be 1 if classes is an int"
                class_labels = torch.tensor([class_labels], dtype=torch.int)
            elif class_labels is None:
                # Randomly generate batch_size class labels
                # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
                class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
            class_labels = class_labels.to(device)
        else:
            class_labels = None
        return class_labels

    def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
        """
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If neither `num_inference_steps` nor `timesteps` is given, if user-supplied `latents` do not
                have shape `(batch_size, unet.config.in_channels, img_size, img_size)`, or if `callback_steps` is not
                a positive integer.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            logger.warning(
                f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
                " `timesteps` will be used over `num_inference_steps`."
            )

        if latents is not None:
            # Use the UNet's own channel count (the same value `__call__` samples with) instead of assuming RGB.
            expected_shape = (batch_size, self.unet.config.in_channels, img_size, img_size)
            if latents.shape != expected_shape:
                raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        batch_size: int = 1,
        class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
        num_inference_steps: int = 1,
        timesteps: List[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
                Optional class labels for conditioning class-conditional consistency models. Will not be used if the
                model is not class-conditional.
            num_inference_steps (`int`, *optional*, defaults to 1):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
                timesteps are used. Must be in descending order. Takes precedence over `num_inference_steps`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. If not provided, a latents tensor will be generated by sampling using the supplied random
                `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.ndarray`) and `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. Must be a positive integer.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
            [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
            tuple, the first element is a list with the generated images.
        """
        # 0. Prepare call parameters
        img_size = self.unet.config.sample_size
        device = self._execution_device

        # 1. Check inputs
        self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)

        # 2. Prepare image latents
        # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=img_size,
            width=img_size,
            dtype=self.unet.dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

        # 3. Handle class_labels for class-conditional models
        class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            # Pass the execution device here as well so both branches place the schedule on the same device.
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            timesteps = self.scheduler.timesteps

        # 5. Denoising loop
        # Multistep sampling: implements Algorithm 1 in the paper
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                scaled_sample = self.scheduler.scale_model_input(sample, t)
                model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]

                sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]

                # call the callback, if provided
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, sample)

        # 6. Post-process image sample
        image = self.postprocess_image(sample, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
Args: controlnets (`List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. You must set multiple `ControlNetModel` as a list. """ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): super().__init__() self.nets = nn.ModuleList(controlnets) def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: List[torch.tensor], conditioning_scale: List[float], class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple]: for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): down_samples, mid_sample = controlnet( sample, timestep, encoder_hidden_states, image, scale, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, guess_mode, return_dict, ) # merge samples if i == 0: down_block_res_samples, mid_block_res_sample = down_samples, mid_sample else: down_block_res_samples = [ samples_prev + samples_curr for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) ] mid_block_res_sample += mid_sample return down_block_res_samples, mid_block_res_sample def save_pretrained( self, save_directory: Union[str, os.PathLike], is_main_process: bool = True, save_function: Callable = None, safe_serialization: bool = False, variant: Optional[str] = None, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. 
is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): The function to use to save the state dictionary. Useful on distributed training like TPUs when one need to replace `torch.save` by another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). variant (`str`, *optional*): If specified, weights are saved in the format pytorch_model..bin. """ idx = 0 model_path_to_save = save_directory for controlnet in self.nets: controlnet.save_pretrained( model_path_to_save, is_main_process=is_main_process, save_function=save_function, safe_serialization=safe_serialization, variant=variant, ) idx += 1 model_path_to_save = model_path_to_save + f"_{idx}" @classmethod def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train the model, you should first set it back in training mode with `model.train()`. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those weights are discarded. 
Parameters: pretrained_model_path (`os.PathLike`): A path to a *directory* containing model weights saved using [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., `./my_model_directory/controlnet`. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype will be automatically derived from the model's weights. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. variant (`str`, *optional*): If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is ignored when using `from_flax`. 
use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. """ idx = 0 controlnets = [] # load controlnet and append to list until no controlnet directory exists anymore # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... model_path_to_load = pretrained_model_path while os.path.isdir(model_path_to_load): controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) controlnets.append(controlnet) idx += 1 model_path_to_load = pretrained_model_path + f"_{idx}" logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") if len(controlnets) == 0: raise ValueError( f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." ) return cls(controlnets) ================================================ FILE: 6DoF/diffusers/pipelines/controlnet/pipeline_controlnet.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... ) >>> image = np.array(image) >>> # get canny image >>> image = cv2.Canny(image, 100, 200) >>> image = image[:, :, None] >>> image = np.concatenate([image, image, image], axis=2) >>> canny_image = Image.fromarray(image) >>> # load control net and stable diffusion v1-5 >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... 
) >>> # speed up diffusion process with faster scheduler and memory optimization >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> # remove following line if xformers is not installed >>> pipe.enable_xformers_memory_efficient_attention() >>> pipe.enable_model_cpu_offload() >>> # generate image >>> generator = torch.manual_seed(0) >>> image = pipe( ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image ... ).images[0] ``` """ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and build the image processors.

    Args:
        vae: Auto-encoder used to move between image and latent space.
        text_encoder: Frozen CLIP text encoder.
        tokenizer: Tokenizer matching `text_encoder`.
        unet: Conditional U-Net that denoises the latents.
        controlnet: A single ControlNet, a list/tuple of ControlNets (wrapped into a
            `MultiControlNetModel` here) or an already built `MultiControlNetModel`.
        scheduler: Scheduler used together with `unet`.
        safety_checker: Optional safety checker; may be `None`.
        feature_extractor: Feature extractor feeding `safety_checker`; required whenever a
            safety checker is given.
        requires_safety_checker: Stored in the pipeline config. When `True` and
            `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If `safety_checker` is set but `feature_extractor` is `None`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message interpolates the class name, so it must be an f-string; without the prefix
        # the literal text "{self.__class__}" was shown to the user.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Several ControlNets are handled through one wrapper module.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )

    # Derived from the number of VAE blocks: 2 ** (number of `block_out_channels` entries - 1).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # The control image processor is created with normalization switched off.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload every model of the pipeline to CPU through `accelerate.cpu_offload`.

    The unet, text encoder, vae and controlnet are each handed to `cpu_offload` with `cuda:{gpu_id}` as the
    execution device; the safety checker, when present, is offloaded as well (including its buffers). Memory
    savings are higher than with `enable_model_cpu_offload`, but performance is lower.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device the models execute on.

    Raises:
        ImportError: If `accelerate` is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    submodels = (self.unet, self.text_encoder, self.vae, self.controlnet)
    for submodel in submodels:
        cpu_offload(submodel, target_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=target_device, offload_buffers=True)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: # the safety checker can offload the vae again _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # control net hook has be manually offloaded as it alternates with unet cpu_offload_with_hook(self.controlnet, device) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """Decode latents with the VAE and return a channels-last float32 numpy array.

    Deprecated: a `FutureWarning` is emitted on every call.

    Args:
        latents: Latent tensor; it is divided by `self.vae.config.scaling_factor` before decoding.

    Returns:
        The decoded images mapped to the [0, 1] range via `x / 2 + 0.5` and clamping, moved to CPU,
        permuted with `(0, 2, 3, 1)` and converted to a float32 numpy array.
    """
    deprecation_message = (
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead"
    )
    warnings.warn(deprecation_message, FutureWarning)

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """Validate the arguments of a pipeline call before any heavy work starts.

    Args:
        prompt: `str`, `list` or `None`; exactly one of `prompt` / `prompt_embeds` must be given.
        image: Conditioning image(s). For a single ControlNet it is handed to `check_image`; for
            multiple ControlNets it must be a flat list with one entry per ControlNet.
        callback_steps: Must be a positive `int`.
        negative_prompt: Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings.
        negative_prompt_embeds: Must have the same shape as `prompt_embeds` when both are given.
        controlnet_conditioning_scale: `float` for a single ControlNet; for multiple ControlNets a
            list, when given, must be flat and have one entry per ControlNet.
        control_guidance_start: Scalar or list of start fractions.
        control_guidance_end: Scalar or list of end fractions.

    Raises:
        ValueError: On inconsistent or out-of-range values.
        TypeError: On wrongly typed `image` / `controlnet_conditioning_scale`, or when the
            pipeline's controlnet is neither a `ControlNetModel` nor a `MultiControlNetModel`.
    """
    # `None` is not an `int`, so a single isinstance test covers the missing-value case too.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Unwrap a compiled controlnet the same way `__call__` does, so that every type check below
    # looks at the underlying model.
    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
    is_single = isinstance(controlnet, ControlNetModel)
    is_multi = isinstance(controlnet, MultiControlNetModel)
    if not (is_single or is_multi):
        # Explicit exception instead of `assert False`, which disappears under `python -O`.
        raise TypeError(
            f"`controlnet` has to be a `ControlNetModel` or a `MultiControlNetModel` but is {type(controlnet)}."
        )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if is_multi and isinstance(prompt, list):
        logger.warning(
            f"You have {len(controlnet.nets)} ControlNets and you have passed {len(prompt)}"
            " prompts. The conditionings will be fixed across the prompts."
        )

    # Check `image`
    if is_single:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")
        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # This length check used to be an `elif` of the `isinstance(..., list)` test above and
        # could therefore never run; it has to be evaluated for every flat list.
        if len(controlnet_conditioning_scale) != len(controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    # Align scalar guidance bounds to lists (mirrors `__call__`), so the signature defaults
    # `0.0` / `1.0` can be validated instead of failing on `len()`.
    start_is_seq = isinstance(control_guidance_start, (list, tuple))
    end_is_seq = isinstance(control_guidance_end, (list, tuple))
    if not start_is_seq and not end_is_seq:
        mult = len(controlnet.nets) if is_multi else 1
        control_guidance_start = mult * [control_guidance_start]
        control_guidance_end = mult * [control_guidance_end]
    elif not start_is_seq:
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not end_is_seq:
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if is_multi:
        if len(control_guidance_start) != len(controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(controlnet.nets)} controlnets available. Make sure to provide {len(controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return the initial latents, scaled by the scheduler's `init_noise_sigma`.

    Args:
        batch_size: Effective batch size; a list of generators must have exactly this length.
        num_channels_latents: Channel count of the latent tensor.
        height: Image height in pixels; integer-divided by `self.vae_scale_factor`.
        width: Image width in pixels; integer-divided by `self.vae_scale_factor`.
        dtype: dtype used when new noise is sampled.
        device: Device for the sampled noise, or the device provided latents are moved to.
        generator: A generator or list of generators forwarded to `randn_tensor`.
        latents: Optional pre-made latents; when given they are only moved to `device`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    downscale = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // downscale, width // downscale)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial = latents.to(device)
    else:
        initial = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial * self.scheduler.init_noise_sigma
If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. 
""" controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) height, width = image.shape[-2:] elif isinstance(controlnet, MultiControlNetModel): images = [] for image_ in image: image_ = self.prepare_image( image=image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) images.append(image_) image = images height, width = image[0].shape[-2:] else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. 
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, 
has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, is_accelerate_available, is_accelerate_version, is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import 
StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... ) >>> np_image = np.array(image) >>> # get canny image >>> np_image = cv2.Canny(np_image, 100, 200) >>> np_image = np_image[:, :, None] >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) >>> canny_image = Image.fromarray(np_image) >>> # load control net and stable diffusion v1-5 >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> # speed up diffusion process with faster scheduler and memory optimization >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image >>> generator = torch.manual_seed(0) >>> image = pipe( ... "futuristic-looking woman", ... num_inference_steps=20, ... generator=generator, ... image=image, ... control_image=canny_image, ... 
def prepare_image(image):
    """Convert an image, or a batch of images, into a float32 tensor.

    Args:
        image: One of

            - a `torch.Tensor`: a 3-D tensor gets a leading batch axis; the values are only cast to
              float32 (no rescaling, no layout change);
            - a single `PIL.Image.Image` or a single 3-D `np.ndarray` (treated as one image);
            - a non-empty list of PIL images or of 3-D numpy arrays.

    Returns:
        `torch.Tensor`: float32 tensor. For PIL / numpy inputs the batch is transposed with
        `(0, 3, 1, 2)` (channels-last to channels-first) and mapped with `x / 127.5 - 1.0`.

    Raises:
        ValueError: If `image` is an empty list.
        TypeError: If `image`, or the first element of a list, has an unsupported type.
    """
    if isinstance(image, torch.Tensor):
        # Batch single image
        if image.ndim == 3:
            image = image.unsqueeze(0)
        return image.to(dtype=torch.float32)

    # Normalise every non-tensor input to a list of items. numpy is tested before PIL so that
    # array inputs are recognised without consulting the PIL type at all.
    if isinstance(image, list):
        items = image
    elif isinstance(image, np.ndarray) or isinstance(image, PIL.Image.Image):
        items = [image]
    else:
        raise TypeError(
            "`image` must be a torch tensor, a PIL image, a numpy array, or a list of PIL images / numpy"
            f" arrays, but is {type(image)}"
        )

    if not items:
        # Previously this surfaced as a bare IndexError from `image[0]`.
        raise ValueError("`image` must contain at least one image, but an empty list was passed.")

    # As in the original implementation, the first element decides how the whole list is handled.
    first = items[0]
    if isinstance(first, np.ndarray):
        batch = np.concatenate([item[None, :] for item in items], axis=0)
    elif isinstance(first, PIL.Image.Image):
        batch = np.concatenate([np.array(item.convert("RGB"))[None, :] for item in items], axis=0)
    else:
        # Previously this surfaced as an AttributeError on `list.transpose`.
        raise TypeError(f"`image` must be a list of PIL images or numpy arrays, but contains {type(first)}")

    batch = batch.transpose(0, 3, 1, 2)
    return torch.from_numpy(batch).to(dtype=torch.float32) / 127.5 - 1.0
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and build the image processors.

    A `list` or `tuple` passed as `controlnet` is wrapped in a `MultiControlNetModel` before registration.

    Args:
        vae: Registered as `self.vae`; its `config.block_out_channels` determines `self.vae_scale_factor`.
        text_encoder: Registered as `self.text_encoder`.
        tokenizer: Registered as `self.tokenizer`.
        unet: Registered as `self.unet`.
        controlnet: A single ControlNet, a list/tuple of ControlNets, or a `MultiControlNetModel`.
        scheduler: Registered as `self.scheduler`.
        safety_checker: Optional safety checker; may be `None`.
        feature_extractor: Required whenever `safety_checker` is not `None`.
        requires_safety_checker: Stored in the pipeline config. When `True` and `safety_checker` is `None`,
            a warning is logged.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # Fix: the literal lacked its `f` prefix, so "{self.__class__}" was printed verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Several ControlNets are driven through a single MultiControlNetModel wrapper.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
    # The control-image processor is configured identically except that normalization is disabled.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's models to CPU with accelerate's `cpu_offload_with_hook`.

    The text encoder, unet and vae (followed by the safety checker when one is set) are hooked one after
    another, each hook receiving the previous one as `prev_module_hook`. The controlnet gets its own
    stand-alone hook because it alternates with the unet and has to be offloaded manually. The last hook
    of the chain is stored in `self.final_offload_hook` so the final model can be offloaded manually.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): index used to build the `cuda:{gpu_id}` execution device.

    Raises:
        ImportError: if accelerate is not available or is older than `0.17.0.dev0`.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_ok:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

    from accelerate import cpu_offload_with_hook

    target_device = torch.device(f"cuda:{gpu_id}")

    # Models that take turns on the GPU, in hooking order.
    chained_models = [self.text_encoder, self.unet, self.vae]
    if self.safety_checker is not None:
        # the safety checker can offload the vae again
        chained_models.append(self.safety_checker)

    last_hook = None
    for model in chained_models:
        _, last_hook = cpu_offload_with_hook(model, target_device, prev_module_hook=last_hook)

    # control net hook has be manually offloaded as it alternates with unet
    cpu_offload_with_hook(self.controlnet, target_device)

    # We'll offload the last model manually.
    self.final_offload_hook = last_hook
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is set and
            `negative_prompt_embeds` is not given.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        `torch.FloatTensor`: the prompt embeddings, repeated `num_images_per_prompt` times per prompt. When
        `do_classifier_free_guidance` is set, the negative embeddings are concatenated in front of the prompt
        embeddings along the first dimension.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given but have different types.
        ValueError: if `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    # Batch size comes from `prompt` when given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize a second time without truncation, only to detect and report text that was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # Pass an attention mask only when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        # Keep only the first element of the text encoder output.
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            # No negative prompt given: use one empty string per batch entry.
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the sequence length of the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # NOTE(review): this reshape uses `batch_size`, whereas the prompt embeddings above use `bs_embed`;
        # the two are presumably equal for valid inputs - confirm against callers.
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
def check_inputs(
    self,
    prompt,
    image,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    """Validate the arguments of a pipeline call and raise on the first inconsistency found.

    Args:
        prompt: `str`, `list` or `None`; exactly one of `prompt` / `prompt_embeds` must be given.
        image: Control image(s). Must be a flat `list` with one entry per ControlNet when several
            ControlNets are used; every image is additionally passed to `self.check_image`.
        callback_steps: Must be a positive `int`.
        negative_prompt: Must not be combined with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings.
        negative_prompt_embeds: Must have the same shape as `prompt_embeds` when both are given.
        controlnet_conditioning_scale: A `float` for a single ControlNet. For several ControlNets a flat
            `list` must have one entry per ControlNet.
        control_guidance_start: A number or a list/tuple of numbers. Scalars are expanded in the same way
            the pipeline call does: to the length of the other argument if that one is a list, otherwise
            to one entry per ControlNet.
        control_guidance_end: Same conventions as `control_guidance_start`.

    Raises:
        ValueError: On an invalid value or an inconsistent combination of values.
        TypeError: When `image` or `controlnet_conditioning_scale` has the wrong type for the ControlNet setup.
    """
    # `None` fails the isinstance test as well, so a separate `is None` check is not needed.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Unwrap a `torch.compile`d controlnet once, so that every check below (including the prompt warning
    # and the guidance-length check) treats compiled and plain modules alike.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    controlnet = self.controlnet._orig_mod if is_compiled else self.controlnet
    is_single = isinstance(controlnet, ControlNetModel)
    is_multi = isinstance(controlnet, MultiControlNetModel)

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if is_multi:
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # Check `image`
    if is_single:
        self.check_image(image, prompt, prompt_embeds)
    elif is_multi:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        if any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        if len(image) != len(controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)
    else:
        # Internal invariant: `self.controlnet` is always one of the two types handled above.
        assert False

    # Check `controlnet_conditioning_scale`
    if is_single:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif is_multi:
        if isinstance(controlnet_conditioning_scale, list):
            if any(isinstance(i, list) for i in controlnet_conditioning_scale):
                raise ValueError("A single batch of multiple conditionings are supported at the moment.")
            # Fix: this length check used to sit in an `elif` behind the `isinstance(..., list)` test above
            # and therefore could never run.
            if len(controlnet_conditioning_scale) != len(controlnet.nets):
                raise ValueError(
                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                    " the same length as the number of controlnets"
                )
    else:
        assert False

    # Fix: the scalar defaults (0.0 / 1.0) used to crash on `len()`. Scalars are expanded with the same
    # rule the pipeline call applies before it invokes this method.
    start_is_seq = isinstance(control_guidance_start, (list, tuple))
    end_is_seq = isinstance(control_guidance_end, (list, tuple))
    if not start_is_seq and end_is_seq:
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif start_is_seq and not end_is_seq:
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not start_is_seq and not end_is_seq:
        mult = len(controlnet.nets) if is_multi else 1
        control_guidance_start = mult * [control_guidance_start]
        control_guidance_end = mult * [control_guidance_end]

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if is_multi:
        if len(control_guidance_start) != len(controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(controlnet.nets)} controlnets available. Make sure to provide {len(controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_control_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of 
generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) elif isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, control_image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.8, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. 
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`; if latents are passed directly, they will not be encoded again. control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. strength (`float`, *optional*, defaults to 0.8): Fraction of the `num_inference_steps` schedule that is actually run. Only the last `int(num_inference_steps * strength)` scheduler steps are executed, and the encoded `image` is noised to the first of those timesteps before denoising starts. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. Note that by default, we use a smaller conditioning scale in this pipeline than for [`~StableDiffusionControlNetPipeline.__call__`]. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying.
Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare image image = self.image_processor.preprocess(image).to(dtype=torch.float32) # 5. Prepare controlnet_conditioning_image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 5. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # If 
we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install transformers accelerate >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> init_image = load_image( ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" ... ) >>> init_image = init_image.resize((512, 512)) >>> generator = torch.Generator(device="cpu").manual_seed(1) >>> mask_image = load_image( ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" ... ) >>> mask_image = mask_image.resize((512, 512)) >>> def make_inpaint_condition(image, image_mask): ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 ... 
assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" ... image[image_mask > 0.5] = -1.0 # set as masked pixel ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) ... image = torch.from_numpy(image) ... return image >>> control_image = make_inpaint_condition(init_image, mask_image) >>> controlnet = ControlNetModel.from_pretrained( ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 ... ) >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image >>> image = pipe( ... "a handsome man with ray-ban sunglasses", ... num_inference_steps=20, ... generator=generator, ... eta=1.0, ... image=init_image, ... mask_image=mask_image, ... control_image=control_image, ... ).images[0] ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the ``image`` and ``1`` for the ``mask``. The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to ``torch.float32`` too. Args: image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. mask (_type_): The mask to apply to the image, i.e. regions to inpaint. 
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. Raises: ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (ot the other way around). Returns: tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ if image is None: raise ValueError("`image` input cannot be undefined.") if mask is None: raise ValueError("`mask_image` input cannot be undefined.") if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask if mask.ndim == 2: mask = mask.unsqueeze(0).unsqueeze(0) # Batch single mask or add channel dim if mask.ndim == 3: # Single batched mask, no channel dim or single mask not batched but channel dim if mask.shape[0] == 1: mask = mask.unsqueeze(0) # Batched masks no channel dim else: mask = mask.unsqueeze(1) assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or image.max() > 1: raise ValueError("Image should be in [-1, 1] range") # Check mask is in [0, 1] if mask.min() < 0 or mask.max() > 1: raise ValueError("Mask should be in [0, 1] range") # Binarize mask mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 # Image 
as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 # preprocess mask if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) masked_image = image * (mask < 0.5) # n.b. ensure backwards compatibility as old function does not return image if return_image: return mask, masked_image, image return mask, masked_image class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) as well as default text-to-image stable diffusion checkpoints, such as [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    r"""
    Register the pipeline components and build the image processors.

    A `list`/`tuple` of ControlNets is wrapped into a single `MultiControlNetModel` before registration.

    Args:
        requires_safety_checker (`bool`, *optional*, defaults to `True`):
            When `True` and `safety_checker` is `None`, a warning is logged. The value is stored in the
            pipeline config.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The `f` prefix was missing, so `{self.__class__}` used to be printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Several ControlNets are driven through one wrapper module.
    if isinstance(controlnet, (list, tuple)):
        controlnet = MultiControlNetModel(controlnet)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=controlnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Ratio between pixel resolution and latent resolution, derived from the VAE block count.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # Control images are converted to RGB and left un-normalized.
    self.control_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload every model of the pipeline to CPU through accelerate's `cpu_offload`.

    The unet, text encoder, vae and controlnet are offloaded with `cuda:{gpu_id}` as execution device. The
    safety checker, when present, is offloaded as well, together with its buffers. Compared to
    `enable_model_cpu_offload`, memory savings are higher but performance is lower.

    Raises:
        ImportError: If accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")

    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    for module in (self.unet, self.text_encoder, self.vae, self.controlnet):
        cpu_offload(module, target_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=target_device, offload_buffers=True)
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
         prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`). A single string is applied to every item of the batch.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        `torch.FloatTensor`: the prompt embeddings repeated `num_images_per_prompt` times. With classifier free
        guidance the unconditional embeddings are concatenated in front of them along the batch dimension.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are of different types.
        ValueError: If a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Warn when the tokenizer had to cut part of the prompt.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # One unconditional row per batch item. With a `str` prompt the batch size is 1, so this
            # only changes the case of pre-computed `prompt_embeds` with batch size > 1, where a
            # single row made the `.view(batch_size * ...)` below produce a wrong shape or fail.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional input to the same sequence length as the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
def check_inputs(
    self,
    prompt,
    image,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    controlnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
):
    r"""
    Validate the user supplied arguments of a pipeline call.

    Args:
        prompt (`str` or `List[str]`, *optional*): exactly one of `prompt` / `prompt_embeds` must be given.
        image: conditioning image(s); a list with one entry per ControlNet when several are used.
        height (`int`), width (`int`): both must be divisible by 8.
        callback_steps (`int`): must be a positive integer.
        negative_prompt, negative_prompt_embeds: mutually exclusive.
        controlnet_conditioning_scale (`float` or `List[float]`): `float` for a single ControlNet; a list
            must contain one entry per ControlNet.
        control_guidance_start, control_guidance_end (`float` or `List[float]`): scalars are broadcast to a
            list; every pair must satisfy `0 <= start < end <= 1`.

    Raises:
        ValueError: If a value is out of range or two arguments are inconsistent.
        TypeError: If `image` or `controlnet_conditioning_scale` has the wrong type.
        AssertionError: If `self.controlnet` is neither a single nor a multi ControlNet.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # `prompt` needs more sophisticated handling when there are multiple
    # conditionings.
    if isinstance(self.controlnet, MultiControlNetModel):
        if isinstance(prompt, list):
            logger.warning(
                f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
                " prompts. The conditionings will be fixed across the prompts."
            )

    # A compiled module wraps the real ControlNet in `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.controlnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single = isinstance(self.controlnet, ControlNetModel) or (
        is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel)
    )
    is_multi = not is_single and (
        isinstance(self.controlnet, MultiControlNetModel)
        or (is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel))
    )
    if not is_single and not is_multi:
        # Explicit raise instead of `assert False`, which is stripped under `python -O`.
        raise AssertionError(f"Unsupported controlnet type: {type(self.controlnet)}")

    # Check `image`
    if is_single:
        self.check_image(image, prompt, prompt_embeds)
    else:
        if not isinstance(image, list):
            raise TypeError("For multiple controlnets: `image` must be type `list`")

        # When `image` is a nested list:
        # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
        elif any(isinstance(i, list) for i in image):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        elif len(image) != len(self.controlnet.nets):
            raise ValueError(
                f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
            )

        for image_ in image:
            self.check_image(image_, prompt, prompt_embeds)

    # Check `controlnet_conditioning_scale`
    if is_single:
        if not isinstance(controlnet_conditioning_scale, float):
            raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
    elif isinstance(controlnet_conditioning_scale, list):
        if any(isinstance(i, list) for i in controlnet_conditioning_scale):
            raise ValueError("A single batch of multiple conditionings are supported at the moment.")
        # NOTE(review): in the original this length check appears to sit in an `elif` that repeats
        # the `if` condition, which would make it unreachable. It is now evaluated for every list.
        if len(controlnet_conditioning_scale) != len(self.controlnet.nets):
            raise ValueError(
                "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
                " the same length as the number of controlnets"
            )

    # The defaults of this method are plain floats, and `len()` on a float raises `TypeError`.
    # Broadcast scalars to lists before the length checks.
    start_is_seq = isinstance(control_guidance_start, (list, tuple))
    end_is_seq = isinstance(control_guidance_end, (list, tuple))
    if not start_is_seq and not end_is_seq:
        count = len(self.controlnet.nets) if isinstance(self.controlnet, MultiControlNetModel) else 1
        control_guidance_start = [control_guidance_start] * count
        control_guidance_end = [control_guidance_end] * count
    elif not start_is_seq:
        control_guidance_start = [control_guidance_start] * len(control_guidance_end)
    elif not end_is_seq:
        control_guidance_end = [control_guidance_end] * len(control_guidance_start)

    if len(control_guidance_start) != len(control_guidance_end):
        raise ValueError(
            f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
        )

    if isinstance(self.controlnet, MultiControlNetModel):
        if len(control_guidance_start) != len(self.controlnet.nets):
            raise ValueError(
                f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
            )

    for start, end in zip(control_guidance_start, control_guidance_end):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
def prepare_control_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    r"""
    Preprocess the conditioning image and expand it to the effective batch size.

    The image is run through `self.control_image_processor.preprocess`, cast to float32, repeated along the
    batch dimension and moved to `device` / `dtype`. A single image is repeated `batch_size` times; a batch
    of images is repeated `num_images_per_prompt` times. With classifier free guidance and `guess_mode`
    disabled, the result is concatenated with itself.

    Returns:
        `torch.Tensor`: the prepared conditioning image batch.
    """
    processed = self.control_image_processor.preprocess(image, height=height, width=width)
    processed = processed.to(dtype=torch.float32)

    # One image: broadcast over the whole batch. Otherwise the image batch is assumed to
    # equal the prompt batch, so each image is repeated once per generated sample.
    copies = batch_size if processed.shape[0] == 1 else num_images_per_prompt

    processed = processed.repeat_interleave(copies, dim=0)
    processed = processed.to(device=device, dtype=dtype)

    if guess_mode or not do_classifier_free_guidance:
        return processed

    return torch.cat([processed, processed])
) if return_image_latents or (latents is None and not is_strength_max): image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents else: noise = latents.to(device) latents = noise * self.scheduler.init_noise_sigma outputs = (latents,) if return_noise: outputs += (noise,) if return_image_latents: outputs += (image_latents,) return outputs def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different # dimensions for each image, so just checking the first image # is not _exactly_ correct, but it is simple. 
while isinstance(image, list): image = image[0] if height is None: if isinstance(image, PIL.Image.Image): height = image.height elif isinstance(image, torch.Tensor): height = image.shape[2] height = (height // 8) * 8 # round down to nearest multiple of 8 if width is None: if isinstance(image, PIL.Image.Image): width = image.width elif isinstance(image, torch.Tensor): width = image.shape[3] width = (width // 8) * 8 # round down to nearest multiple of 8 return height, width # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents def prepare_mask_latents( self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision mask = torch.nn.functional.interpolate( mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) ) mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: if not batch_size % mask.shape[0] == 0: raise ValueError( "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" " of masks that you pass is divisible by the total requested batch size." ) mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) if masked_image_latents.shape[0] < batch_size: if not batch_size % masked_image_latents.shape[0] == 0: raise ValueError( "The passed images and the required batch size don't match. 
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    r"""
    Encode `image` with the VAE and return the sampled, scaled latents.

    When `generator` is a list, every batch item is encoded separately and sampled with the generator at
    the same index; the samples are concatenated along dimension 0. Otherwise the whole batch is encoded
    and sampled at once. The result is multiplied by `self.vae.config.scaling_factor`.
    """
    if not isinstance(generator, list):
        latents = self.vae.encode(image).latent_dist.sample(generator=generator)
    else:
        per_sample = []
        for index in range(image.shape[0]):
            distribution = self.vae.encode(image[index : index + 1]).latent_dist
            per_sample.append(distribution.sample(generator=generator[index]))
        latents = torch.cat(per_sample, dim=0)

    return self.vae.config.scaling_factor * latents
List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.5, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. strength (`float`, *optional*, defaults to 1.): Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be between 0 and 1. 
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. 
If multiple ControlNets are specified in init, you can set the corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting than for [`~StableDiffusionControlNetPipeline.__call__`]. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet # 0. 
Default height and width to unet height, width = self._default_height_width(height, width, image) # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, ControlNetModel) else controlnet.nets[0].config.global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Prepare image if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) elif isinstance(controlnet, MultiControlNetModel): control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 4. Preprocess mask and image - resizes image and mask w.r.t height and width mask, masked_image, init_image = prepare_mask_and_masked_image( image, mask_image, height, width, return_image=True ) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise is_strength_max = strength == 1.0 # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_unet = self.unet.config.in_channels return_image_latents = num_channels_unet == 4 latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, return_noise=True, return_image_latents=return_image_latents, ) if return_image_latents: latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, masked_image, batch_size * num_images_per_prompt, height, width, prompt_embeds.dtype, device, generator, do_classifier_free_guidance, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
control_model_input = latents control_model_input = self.scheduler.scale_model_input(control_model_input, t) controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=control_image, conditioning_scale=cond_scale, guess_mode=guess_mode, return_dict=False, ) if guess_mode and do_classifier_free_guidance: # Infered ControlNet only for the conditional batch. # To apply the output of ControlNet to both the unconditional and conditional batches, # add 0 to the unconditional batch to keep it unchanged. down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual if num_channels_unet == 9: latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if num_channels_unet == 4: init_latents_proper = image_latents[:1] init_mask = mask[:1] if i < len(timesteps) - 1: 
noise_timestep = timesteps[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") torch.cuda.empty_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from ..stable_diffusion import FlaxStableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import jax
        >>> import numpy as np
        >>> import jax.numpy as jnp
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> from diffusers.utils import load_image
        >>> from PIL import Image
        >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel


        >>> def image_grid(imgs, rows, cols):
        ...     w, h = imgs[0].size
        ...     grid = Image.new("RGB", size=(cols * w, rows * h))
        ...     for i, img in enumerate(imgs):
        ...         grid.paste(img, box=(i % cols * w, i // cols * h))
        ...     return grid


        >>> def create_key(seed=0):
        ...     return jax.random.PRNGKey(seed)


        >>> rng = create_key(0)

        >>> # get canny image
        >>> canny_image = load_image(
        ...     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
        ... )

        >>> prompts = "best quality, extremely detailed"
        >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"

        >>> # load control net and stable diffusion v1-5
        >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
        ... )
        >>> params["controlnet"] = controlnet_params

        >>> num_samples = jax.device_count()
        >>> rng = jax.random.split(rng, jax.device_count())

        >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
        >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
        >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

        >>> p_params = replicate(params)
        >>> prompt_ids = shard(prompt_ids)
        >>> negative_prompt_ids = shard(negative_prompt_ids)
        >>> processed_image = shard(processed_image)

        >>> output = pipe(
        ...     prompt_ids=prompt_ids,
        ...     image=processed_image,
        ...     params=p_params,
        ...     prng_seed=rng,
        ...     num_inference_steps=50,
        ...     neg_prompt_ids=negative_prompt_ids,
        ...     jit=True,
        ... ).images

        >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
        >>> output_images = image_grid(output_images, num_samples // 4, 4)
        >>> output_images.save("generated_image.png")
        ```
"""


class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the encoded image latents.
        controlnet ([`FlaxControlNetModel`]):
            Provides additional conditioning to the unet during the denoising process.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        controlnet: FlaxControlNetModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        """Register the sub-models and derive the VAE down-sampling factor.

        Passing `safety_checker=None` is accepted; a warning is logged and the NSFW filtering in `__call__` is
        skipped.
        """
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Each VAE block after the first halves the spatial resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
        """Tokenize `prompt` (padded/truncated to the tokenizer's maximum length) and return the token ids.

        Raises:
            ValueError: if `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )

        return text_input.input_ids

    def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
        """Run `preprocess` on one or several PIL images and concatenate the results along the first axis.

        Raises:
            ValueError: if `image` is neither a `PIL.Image.Image` nor a `list`.
        """
        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        return processed_images

    def _get_has_nsfw_concepts(self, features, params):
        """Call the safety checker on pre-extracted `features` and return its result."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """Black out the images flagged by the safety checker.

        Args:
            images: array of `uint8` images; it is copied before the first modification, so the caller's array is
                left untouched.
            safety_model_params: safety checker parameters (must already be replicated when `jit` is True).
            jit: whether to use the `pmap`-ed safety checker.

        Returns:
            A tuple `(images, has_nsfw_concepts)`.
        """
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            # NOTE(review): the un-replicated params are not read again below.
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    # Copy lazily: only pay for it when at least one image is flagged.
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        # Warn once per call rather than once per image.
        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        # Kept as a string so the annotation is never evaluated: `jax.random.KeyArray` is not available in
        # recent JAX releases and evaluating it would fail while the class body is being executed.
        prng_seed: "jax.random.KeyArray",
        num_inference_steps: int,
        guidance_scale: float,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
        controlnet_conditioning_scale: float = 1.0,
    ):
        """Run the ControlNet-guided denoising loop and decode the final latents with the VAE.

        Classifier-free guidance is always applied: the unconditional and conditional branches are stacked in a
        single batch for both the ControlNet and the UNet.

        Returns:
            The decoded images, rescaled from `[-1, 1]` to `[0, 1]`, clipped, and transposed with axes
            `(0, 2, 3, 1)`.

        Raises:
            ValueError: if the last two dimensions of `image` are not both divisible by 64, or if user-supplied
                `latents` do not have the expected shape.
        """
        height, width = image.shape[-2:]
        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            # No negative prompt given: use the empty prompt for the unconditional branch.
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # The conditioning image is duplicated to match the doubled (uncond + cond) batch.
        image = jnp.concatenate([image] * 2)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
                {"params": params["controlnet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                controlnet_cond=image,
                conditioning_scale=controlnet_conditioning_scale,
                return_dict=False,
            )

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            ).sample

            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        image: jnp.array,
        params: Union[Dict, FrozenDict],
        # String annotation on purpose, see `_generate`.
        prng_seed: "jax.random.KeyArray",
        num_inference_steps: int = 50,
        guidance_scale: Union[float, jnp.array] = 7.5,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
        controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                The prompt or prompts to guide the image generation.
            image (`jnp.array`):
                Array representing the ControlNet input condition. ControlNet uses this input condition to generate
                guidance to Unet.
            params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
            prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied `prng_seed`.
            neg_prompt_ids (`jnp.array`, *optional*):
                Token ids of the negative prompt(s). If not provided, the empty prompt is tokenized and used for the
                unconditional branch.
            controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        height, width = image.shape[-2:]

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if isinstance(controlnet_conditioning_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
            )
        else:
            images = self._generate(
                prompt_ids,
                image,
                params,
                prng_seed,
                num_inference_steps,
                guidance_scale,
                latents,
                neg_prompt_ids,
                controlnet_conditioning_scale,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = (images * 255).round().astype("uint8")
            num_devices, batch_size = images.shape[:2]

            # The safety checker works on a flat list of images.
            images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
            images = np.array(images)

            # block images
            # NOTE(review): `i` enumerates the flattened (num_devices * batch_size) results while `images` has
            # not been flattened at this point -- confirm the indexing is what is intended.
            if any(has_nsfw_concept):
                for i, is_nsfw in enumerate(has_nsfw_concept):
                    if is_nsfw:
                        images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(num_devices, batch_size, height, width, 3)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial( jax.pmap, in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), static_broadcasted_argnums=(0, 5), ) def _p_generate( pipe, prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ): return pipe._generate( prompt_ids, image, params, prng_seed, num_inference_steps, guidance_scale, latents, neg_prompt_ids, controlnet_conditioning_scale, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) def _p_get_has_nsfw_concepts(pipe, features, params): return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): # einops.rearrange(x, 'd b ... -> (d b) ...') num_devices, batch_size = x.shape[:2] rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) def preprocess(image, dtype): image = image.convert("RGB") w, h = image.size w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) return image ================================================ FILE: 6DoF/diffusers/pipelines/dance_diffusion/__init__.py ================================================ from .pipeline_dance_diffusion import DanceDiffusionPipeline ================================================ FILE: 6DoF/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ...utils import logging, randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name class DanceDiffusionPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of [`IPNDMScheduler`]. """ def __init__(self, unet, scheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, num_inference_steps: int = 100, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, audio_length_in_s: Optional[float] = None, return_dict: bool = True, ) -> Union[AudioPipelineOutput, Tuple]: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of audio samples to generate. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality audio sample at the expense of slower inference. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. audio_length_in_s (`float`, *optional*, defaults to `self.unet.config.sample_size/self.unet.config.sample_rate`): The length of the generated audio sample in seconds. Note that the output of the pipeline, *i.e.* `sample_size`, will be `audio_length_in_s` * `self.unet.config.sample_rate`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.AudioPipelineOutput`] or `tuple`: [`~pipelines.utils.AudioPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ if audio_length_in_s is None: audio_length_in_s = self.unet.config.sample_size / self.unet.config.sample_rate sample_size = audio_length_in_s * self.unet.config.sample_rate down_scale_factor = 2 ** len(self.unet.up_blocks) if sample_size < 3 * down_scale_factor: raise ValueError( f"{audio_length_in_s} is too small. Make sure it's bigger or equal to" f" {3 * down_scale_factor / self.unet.config.sample_rate}." ) original_sample_size = int(sample_size) if sample_size % down_scale_factor != 0: sample_size = ( (audio_length_in_s * self.unet.config.sample_rate) // down_scale_factor + 1 ) * down_scale_factor logger.info( f"{audio_length_in_s} is increased to {sample_size / self.unet.config.sample_rate} so that it can be handled" f" by the model. It will be cut to {original_sample_size / self.unet.config.sample_rate} after the denoising" " process." ) sample_size = int(sample_size) dtype = next(iter(self.unet.parameters())).dtype shape = (batch_size, self.unet.config.in_channels, sample_size) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype) # set step values self.scheduler.set_timesteps(num_inference_steps, device=audio.device) self.scheduler.timesteps = self.scheduler.timesteps.to(dtype) for t in self.progress_bar(self.scheduler.timesteps): # 1. 
predict noise model_output model_output = self.unet(audio, t).sample # 2. compute previous image: x_t -> t_t-1 audio = self.scheduler.step(model_output, t, audio).prev_sample audio = audio.clamp(-1, 1).float().cpu().numpy() audio = audio[:, :, :original_sample_size] if not return_dict: return (audio,) return AudioPipelineOutput(audios=audio) ================================================ FILE: 6DoF/diffusers/pipelines/ddim/__init__.py ================================================ from .pipeline_ddim import DDIMPipeline ================================================ FILE: 6DoF/diffusers/pipelines/ddim/pipeline_ddim.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ...schedulers import DDIMScheduler from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DDIMPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image. 
Can be one of [`DDPMScheduler`], or [`DDIMScheduler`]. """ def __init__(self, unet, scheduler): super().__init__() # make sure scheduler can always be converted to DDIM scheduler = DDIMScheduler.from_config(scheduler.config) self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, eta: float = 0.0, num_inference_steps: int = 50, use_clipped_model_output: Optional[bool] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. eta (`float`, *optional*, defaults to 0.0): The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. use_clipped_model_output (`bool`, *optional*, defaults to `None`): if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed downstream to the scheduler. So use `None` for schedulers which don't support this argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
When returning a tuple, the first element is a list with the generated images. """ # Sample gaussian noise to begin loop if isinstance(self.unet.config.sample_size, int): image_shape = ( batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size, ) else: image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype) # set step values self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output model_output = self.unet(image, t).sample # 2. predict previous mean of image x_t-1 and add variance depending on eta # eta corresponds to η in paper and should be between [0, 1] # do x_t -> x_t-1 image = self.scheduler.step( model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator ).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/ddpm/__init__.py ================================================ from .pipeline_ddpm import DDPMPipeline ================================================ FILE: 6DoF/diffusers/pipelines/ddpm/pipeline_ddpm.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DDPMPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of [`DDPMScheduler`], or [`DDIMScheduler`]. """ def __init__(self, unet, scheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, num_inference_steps: int = 1000, output_type: Optional[str] = "pil", return_dict: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. num_inference_steps (`int`, *optional*, defaults to 1000): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # Sample gaussian noise to begin loop if isinstance(self.unet.config.sample_size, int): image_shape = ( batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size, ) else: image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) if self.device.type == "mps": # randn does not work reproducibly on mps image = randn_tensor(image_shape, generator=generator) image = image.to(self.device) else: image = randn_tensor(image_shape, generator=generator, device=self.device) # set step values self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output model_output = self.unet(image, t).sample # 2. 
compute previous image: x_t -> x_t-1 image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/__init__.py ================================================ from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available from .timesteps import ( fast27_timesteps, smart27_timesteps, smart50_timesteps, smart100_timesteps, smart185_timesteps, super27_timesteps, super40_timesteps, super100_timesteps, ) @dataclass class IFPipelineOutput(BaseOutput): """ Args: Output class for Stable Diffusion pipelines. images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. nsfw_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content or a watermark. `None` if safety checking could not be performed. watermark_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety checking could not be performed. 
""" images: Union[List[PIL.Image.Image], np.ndarray] nsfw_detected: Optional[List[bool]] watermark_detected: Optional[List[bool]] try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_if import IFPipeline from .pipeline_if_img2img import IFImg2ImgPipeline from .pipeline_if_img2img_superresolution import IFImg2ImgSuperResolutionPipeline from .pipeline_if_inpainting import IFInpaintingPipeline from .pipeline_if_inpainting_superresolution import IFInpaintingSuperResolutionPipeline from .pipeline_if_superresolution import IFSuperResolutionPipeline from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/pipeline_if.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, is_accelerate_available, is_accelerate_version, is_bs4_available, is_ftfy_available, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
import IFPipelineOutput from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline >>> from diffusers.utils import pt_to_pil >>> import torch >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) >>> pipe.enable_model_cpu_offload() >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt" ... ).images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> safety_modules = { ... "feature_extractor": pipe.feature_extractor, ... "safety_checker": pipe.safety_checker, ... "watermarker": pipe.watermarker, ... } >>> super_res_2_pipe = DiffusionPipeline.from_pretrained( ... "stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16 ... ) >>> super_res_2_pipe.enable_model_cpu_offload() >>> image = super_res_2_pipe( ... prompt=prompt, ... image=image, ... 
).images >>> image[0].save("./if_stage_II.png") ``` """ class IFPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel unet: UNet2DConditionModel scheduler: DDPMScheduler feature_extractor: Optional[CLIPImageProcessor] safety_checker: Optional[IFSafetyChecker] watermarker: Optional[IFWatermarker] bad_punct_regex = re.compile( r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" ) # noqa _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"] def __init__( self, tokenizer: T5Tokenizer, text_encoder: T5EncoderModel, unet: UNet2DConditionModel, scheduler: DDPMScheduler, safety_checker: Optional[IFSafetyChecker], feature_extractor: Optional[CLIPImageProcessor], watermarker: Optional[IFWatermarker], requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the IF license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, watermarker=watermarker, ) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.text_encoder, self.unet, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
self.final_offload_hook = hook def remove_all_hooks(self): if is_accelerate_available(): from accelerate.hooks import remove_hook_from_module else: raise ImportError("Please install accelerate via `pip install accelerate`") for model in [self.text_encoder, self.unet, self.safety_checker]: if model is not None: remove_hook_from_module(model, recurse=True) self.unet_offload_hook = None self.text_encoder_offload_hook = None self.final_offload_hook = None @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. 
If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" 
f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the prompt-related call arguments.

    The checks run in a fixed order and the first violated one raises.

    Args:
        prompt: A `str` or `list`, or `None` when `prompt_embeds` is given.
        callback_steps: Must be an `int` strictly greater than zero.
        negative_prompt: Optional; mutually exclusive with
            `negative_prompt_embeds`.
        prompt_embeds: Optional; mutually exclusive with `prompt`.
        negative_prompt_embeds: Optional; when given together with
            `prompt_embeds`, both must expose equal `.shape` values.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # `None` and non-int values fail the isinstance test, so the comparison
    # is only evaluated for real integers.
    steps_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def _clean_caption(self, caption):
    """Normalize a raw caption string before it is tokenized.

    The caption is URL-unquoted, lower-cased and then passed through a fixed
    sequence of substitutions that remove URLs, HTML markup, @-mentions,
    several CJK ranges, ids, file names and boilerplate phrases, and that
    unify dashes, quotes and whitespace. The order of the substitutions is
    significant because later patterns operate on the output of earlier ones.

    Relies on `BeautifulSoup` and `ftfy` being in scope and on
    `self.bad_punct_regex`.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace stripped.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The pattern must be the literal "<person>" tag: an empty pattern matches
    # at every position and would insert "person" between all characters.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html: keep only the text content
    caption = BeautifulSoup(caption, features="html.parser").text

    # @mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot; (HTML entity, with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # &amp (HTML entity prefix; a bare "&" is left for bad_punct_regex below)
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # runs of repeated punctuation
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Separators are only replaced when more than three of them occur.
    separator = re.compile(r"(?:\-|\_)")
    if len(re.findall(separator, caption)) > 3:
        caption = re.sub(separator, " ", caption)

    caption = ftfy.fix_text(caption)
    # Unescape twice so doubly-escaped entities are resolved as well.
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "640x480" (latin x, cyrillic х or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE: the original code called `caption.strip()` here and discarded the
    # result. Assigning it would change which captions match the anchored
    # patterns below, so whitespace is deliberately trimmed only on return.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
torch.FloatTensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters height = height or self.unet.config.sample_size width = width or self.unet.config.sample_size if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare intermediate images intermediate_images = self.prepare_intermediate_images( batch_size * num_images_per_prompt, self.unet.config.in_channels, height, width, prompt_embeds.dtype, device, generator, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = ( torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images ) model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) # compute the previous noisy sample 
x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = intermediate_images if output_type == "pil": # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL image = self.numpy_to_pil(image) # 11. Apply watermark if self.watermarker is not None: image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() else: # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. 
Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, nsfw_detected, watermark_detected) return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, is_bs4_available, is_ftfy_available, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so its shorter side equals `img_size`.

    The longer side is scaled by the source aspect ratio and rounded to a
    multiple of 8. Resampling uses the "bicubic" entry of `PIL_INTERPOLATION`.

    Args:
        images: Source image; only `.size` and `.resize` are used.
        img_size: Target length of the shorter side.

    Returns:
        The resized image returned by `images.resize`.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # Landscape or square: height is fixed, width follows the aspect ratio.
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is fixed, height follows the aspect ratio.
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate the safety-checker setup.

    Args:
        tokenizer: Registered as the `tokenizer` module.
        text_encoder: Registered as the `text_encoder` module.
        unet: Registered as the `unet` module.
        scheduler: Registered as the `scheduler` module.
        safety_checker: Optional; passing `None` while
            `requires_safety_checker` is true only logs a warning.
        feature_extractor: Required whenever `safety_checker` is given.
        watermarker: Registered as the `watermarker` module.
        requires_safety_checker: Stored in the pipeline config.

    Raises:
        ValueError: If `safety_checker` is given without `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message needs the f-prefix; without it `{self.__class__}` was
        # emitted literally instead of naming the pipeline class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
self.final_offload_hook = hook # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks def remove_all_hooks(self): if is_accelerate_available(): from accelerate.hooks import remove_hook_from_module else: raise ImportError("Please install accelerate via `pip install accelerate`") for model in [self.text_encoder, self.unet, self.safety_checker]: if model is not None: remove_hook_from_module(model, recurse=True) self.unet_offload_hook = None self.text_encoder_offload_hook = None self.final_offload_hook = None @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif 
isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and
    `generator` are only forwarded when the signature names them.
    eta (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Value forwarded as `generator` when accepted.
        eta: Value forwarded as `eta` when accepted.

    Returns:
        A dict holding the accepted subset of `eta` and `generator`,
        inserted in that order.
    """
    accepted = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in accepted}
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if isinstance(image, list): check_image_type = image[0] else: check_image_type = image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" f" {type(check_image_type)}" ) if isinstance(image, list): image_batch_size = len(image) elif isinstance(image, torch.Tensor): image_batch_size = image.shape[0] elif isinstance(image, PIL.Image.Image): image_batch_size = 1 elif isinstance(image, np.ndarray): image_batch_size = image.shape[0] else: assert False if batch_size != image_batch_size: raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing def _text_preprocessing(self, text, clean_caption=False): if clean_caption and not is_bs4_available(): logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") clean_caption = False if clean_caption and not is_ftfy_available(): logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") clean_caption = False if not isinstance(text, (tuple, list)): text = [text] def process(text: str): if clean_caption: text = self._clean_caption(text) text = self._clean_caption(text) else: text = text.lower().strip() return text return [process(t) for t in text] # Copied from 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Strip web-scrape noise from a caption and return the cleaned text.

    The caption is URL-unquoted, lower-cased, and then run through a fixed
    sequence of substitutions (URLs, HTML, @-mentions, CJK ranges, dash and
    quote normalization, ids, filenames, marketing phrases, ...). The order of
    the substitutions matters and is preserved as-is.

    Requires ``bs4`` (``BeautifulSoup``) and ``ftfy`` to be importable, and
    reads the class-level ``bad_punct_regex``.

    Args:
        caption: Any object; it is converted with ``str()`` first.

    Returns:
        str: The cleaned, stripped caption.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The "<person>" placeholder tag becomes the plain word. (The pattern must
    # not be empty: an empty pattern matches between every character.)
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @-mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot; entity (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # &amp entity
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal "\n" sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated punctuation
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "1920x1080" (latin x, cyrillic х, or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the result of this call is discarded, so it has no effect.
    # Left unchanged on purpose: assigning it would alter the text fed to the
    # text encoder relative to the current behavior - confirm before changing.
    caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.7,
    num_inference_steps: int = 80,
    timesteps: List[int] = None,
    guidance_scale: float = 10.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    clean_caption: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or a list of those):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        strength (`float`, *optional*, defaults to 0.7):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 80):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. Overridden by `len(timesteps)` when `timesteps` is given.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 10.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only forwarded
            to schedulers whose `step` accepts an `eta` argument; ignored otherwise.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns `PIL.Image.Image` objects (safety checker
            and watermark applied), `"pt"` returns the raw tensor from the denoising loop (no post-processing,
            no safety checker), and any other value returns a `np.array` (safety checker applied, no watermark).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an [`IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is passed through unchanged to the `unet` call on every denoising step.

    Examples:

    Returns:
        [`IFPipelineOutput`] or `tuple`:
        [`IFPipelineOutput`] if `return_dict` is True, otherwise the `tuple`
        `(images, nsfw_detected, watermark_detected)`. `nsfw_detected` and `watermark_detected` are whatever
        `run_safety_checker` reports, or `None` when `output_type == "pt"`.
    """
    # 1. Check inputs. Raise error if not correct
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        # No prompt given: the batch size comes from the pre-computed embeddings.
        batch_size = prompt_embeds.shape[0]

    self.check_inputs(
        prompt, image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    if do_classifier_free_guidance:
        # Unconditional embeddings first, conditional second: the `chunk(2)` in the
        # denoising loop relies on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # Keep only the tail of the schedule selected by `strength`.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. Prepare intermediate images
    image = self.preprocess_image(image)
    image = image.to(device=device, dtype=dtype)

    # The first remaining timestep is the noise level applied to the input image,
    # repeated once per generated sample.
    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, generator
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # With guidance, the batch is doubled so one forward pass yields both predictions.
            model_input = (
                torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
            )
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Each half is split along dim 1 at the input channel count: the first
                # part is the noise prediction, the rest is treated as predicted variance.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # Schedulers without a learned variance only receive the noise part.
            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 8. Post-processing
        # Map from [-1, 1] to [0, 1] and move channels last for the numpy/PIL conversion.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 10. Convert to PIL
        image = self.numpy_to_pil(image)

        # 11. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        # Raw tensor output: no post-processing and no safety checker.
        nsfw_detected = None
        watermark_detected = None

        # The safety checker did not run, so the unet is offloaded explicitly.
        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 8. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 9. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so its shorter side equals ``img_size``.

    The longer side is scaled by the aspect ratio and rounded to a multiple
    of 8. Resampling is bicubic.

    Args:
        images: The image to resize (a single PIL image despite the name).
        img_size: Target length of the shorter side, in pixels.

    Returns:
        The resized image.
    """
    src_width, src_height = images.size
    aspect = src_width / src_height

    if aspect >= 1:
        # Landscape or square: height is fixed, width follows the aspect ratio.
        target_size = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is fixed, height follows the aspect ratio.
        target_size = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target_size, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate their combination.

    Args:
        tokenizer: Tokenizer for the text encoder.
        text_encoder: Text encoder producing the prompt embeddings.
        unet: Denoising model; a warning is logged unless it takes 6 input channels.
        scheduler: Scheduler registered as ``scheduler``.
        image_noising_scheduler: Scheduler registered as ``image_noising_scheduler``.
        safety_checker: Optional safety checker; requires ``feature_extractor``.
        feature_extractor: Image processor that feeds the safety checker.
        watermarker: Optional watermarker.
        requires_safety_checker: Stored in the config; when true and no safety
            checker is given, a warning is logged.

    Raises:
        ValueError: If ``safety_checker`` is given without ``feature_extractor``.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message interpolates the class, so it must be an f-string.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # f-string so the checkpoint path and channel count are actually reported;
        # `Logger.warning` replaces the deprecated `Logger.warn` alias.
        logger.warning(
            f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.text_encoder, self.unet, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Strip web-scrape noise from a caption and return the cleaned text.

    The caption is URL-unquoted, lower-cased, and then run through a fixed
    sequence of substitutions (URLs, HTML, @-mentions, CJK ranges, dash and
    quote normalization, ids, filenames, marketing phrases, ...). The order of
    the substitutions matters and is preserved as-is.

    Requires ``bs4`` (``BeautifulSoup``) and ``ftfy`` to be importable, and
    reads the class-level ``bad_punct_regex``.

    Args:
        caption: Any object; it is converted with ``str()`` first.

    Returns:
        str: The cleaned, stripped caption.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # The "<person>" placeholder tag becomes the plain word. (The pattern must
    # not be empty: an empty pattern matches between every character.)
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @-mentions
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot; entity (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # &amp entity
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal "\n" sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # repeated punctuation
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as "1920x1080" (latin x, cyrillic х, or ×)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the result of this call is discarded, so it has no effect.
    # Left unchanged on purpose: assigning it would alter the text fed to the
    # text encoder relative to the current behavior - confirm before changing.
    caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt def encode_prompt( self, prompt, do_classifier_free_guidance=True, num_images_per_prompt=1, device=None, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on `image` and offload the UNet afterwards.

    Args:
        image: images to check; they are converted with `self.numpy_to_pil` before being handed to
            `self.feature_extractor`.
        device: device the feature-extractor output is moved to.
        dtype: dtype the extracted `pixel_values` are cast to before being passed as `clip_input`.

    Returns:
        `tuple`: `(image, nsfw_detected, watermark_detected)`. Without a safety checker the input `image` is
        returned unchanged and both flags are `None`.
    """
    # Defaults used when no safety checker is configured.
    nsfw_detected = watermark_detected = None

    checker = self.safety_checker
    if checker is not None:
        pil_images = self.numpy_to_pil(image)
        checker_input = self.feature_extractor(pil_images, return_tensors="pt").to(device)
        image, nsfw_detected, watermark_detected = checker(
            images=image,
            clip_input=checker_input.pixel_values.to(dtype=dtype),
        )

    # The offload hook is optional: it may be missing entirely or set to None.
    offload_hook = getattr(self, "unet_offload_hook", None)
    if offload_hook is not None:
        offload_hook.offload()

    return image, nsfw_detected, watermark_detected
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, original_image, batch_size, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # image if isinstance(image, list): check_image_type = image[0] else: check_image_type = image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" f" {type(check_image_type)}" ) if isinstance(image, list): image_batch_size = len(image) elif isinstance(image, torch.Tensor): image_batch_size = image.shape[0] elif isinstance(image, PIL.Image.Image): image_batch_size = 1 elif isinstance(image, np.ndarray): image_batch_size = image.shape[0] else: assert False if batch_size != image_batch_size: raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") # original_image if isinstance(original_image, list): check_image_type = original_image[0] else: check_image_type = original_image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`original_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
but is" f" {type(check_image_type)}" ) if isinstance(original_image, list): image_batch_size = len(original_image) elif isinstance(original_image, torch.Tensor): image_batch_size = original_image.shape[0] elif isinstance(original_image, PIL.Image.Image): image_batch_size = 1 elif isinstance(original_image, np.ndarray): image_batch_size = original_image.shape[0] else: assert False if batch_size != image_batch_size: raise ValueError( f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}" ) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.preprocess_image with preprocess_image -> preprocess_original_image def preprocess_original_image(self, image: PIL.Image.Image) -> torch.Tensor: if not isinstance(image, list): image = [image] def numpy_to_pt(images): if images.ndim == 3: images = images[..., None] images = torch.from_numpy(images.transpose(0, 3, 1, 2)) return images if isinstance(image[0], PIL.Image.Image): new_image = [] for image_ in image: image_ = image_.convert("RGB") image_ = resize(image_, self.unet.sample_size) image_ = np.array(image_) image_ = image_.astype(np.float32) image_ = image_ / 127.5 - 1 new_image.append(image_) image = new_image image = np.stack(image, axis=0) # to np image = numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) image = numpy_to_pt(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) return image # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor: if not isinstance(image, torch.Tensor) and not isinstance(image, list): image = [image] if isinstance(image[0], PIL.Image.Image): image = 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """
    Pick the tail of the scheduler's timestep schedule that corresponds to `strength`.

    Args:
        num_inference_steps: length of the full schedule.
        strength: fraction of the schedule to run; `int(num_inference_steps * strength)` steps are kept,
            capped at `num_inference_steps`.

    Returns:
        `tuple`: the slice `self.scheduler.timesteps[t_start:]` and the number of steps it represents
        (`num_inference_steps - t_start`).
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    # Index of the first timestep that is actually executed; never negative.
    first_index = max(num_inference_steps - steps_to_run, 0)
    return self.scheduler.timesteps[first_index:], num_inference_steps - first_index
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
    original_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.8,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: int = 50,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 250,
    clean_caption: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        image (`torch.FloatTensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        original_image (`torch.FloatTensor` or `PIL.Image.Image`):
            The original image that `image` was varied from. Although the parameter defaults to `None`, a value
            is required: `check_inputs` rejects anything that is not a tensor, PIL image, array or list of those.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns `PIL.Image.Image` objects, `"pt"` returns
            the raw tensor without running the safety checker, and any other value returns an `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        noise_level (`int`, *optional*, defaults to 250):
            The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is a list with the generated images, and the second element is a list
        of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
        or watermarked content, according to the `safety_checker`.
    """
    # 1. Check inputs. Raise error if not correct
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    elif prompt_embeds is not None:
        batch_size = prompt_embeds.shape[0]
    else:
        # Neither a usable prompt nor embeddings: leave the batch size unset so that `check_inputs` below
        # raises its descriptive ValueError instead of an AttributeError on `None.shape` here.
        batch_size = None

    self.check_inputs(
        prompt,
        image,
        original_image,
        batch_size,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    device = self._execution_device

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    if do_classifier_free_guidance:
        # unconditional embeddings first, so `chunk(2)` in the loop yields (uncond, text)
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. prepare original image
    original_image = self.preprocess_original_image(original_image)
    original_image = original_image.to(device=device, dtype=dtype)

    # 6. Prepare intermediate images
    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        original_image,
        noise_timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        generator,
    )

    # 7. Prepare upscaled image and noise level
    _, _, height, width = original_image.shape

    image = self.preprocess_image(image, num_images_per_prompt, device)

    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    if do_classifier_free_guidance:
        # the model input is duplicated for guidance, so the class labels must be duplicated too
        noise_level = torch.cat([noise_level] * 2)

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            model_input = torch.cat([intermediate_images, upscaled], dim=1)

            model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 10. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 11. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 12. Convert to PIL
        image = self.numpy_to_pil(image)

        # 13. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        # raw tensors are returned as-is; the safety checker is not run for this output type
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 10. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 11. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize an image so its shorter side is `img_size`, keeping the aspect ratio.

    The longer side is scaled by the aspect ratio and rounded to a multiple
    of 8. Resampling is bicubic.

    Args:
        images: The PIL image to resize.
        img_size: Target length of the shorter side, in pixels.

    Returns:
        The resized PIL image.
    """
    source_w, source_h = images.size
    aspect = source_w / source_h
    eighth = img_size / 8

    if aspect >= 1:
        # Landscape or square: height is pinned, width follows the aspect ratio.
        target = (int(round(eighth * aspect) * 8), img_size)
    else:
        # Portrait: width is pinned, height follows the aspect ratio.
        target = (img_size, int(round(eighth / aspect) * 8))

    return images.resize(target, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate the safety-checker setup.

    Args:
        tokenizer: Tokenizer registered as `self.tokenizer`.
        text_encoder: Text encoder registered as `self.text_encoder`.
        unet: Denoising network registered as `self.unet`.
        scheduler: Scheduler registered as `self.scheduler`.
        safety_checker: Optional safety checker; `None` disables it.
        feature_extractor: Required whenever `safety_checker` is given.
        watermarker: Optional watermarker registered as `self.watermarker`.
        requires_safety_checker: When `True` and `safety_checker` is `None`,
            a warning is logged. The value is stored in the pipeline config.

    Raises:
        ValueError: If a safety checker is passed without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message is an f-string so that `{self.__class__}` is interpolated
        # instead of being printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)


# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
    models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
    when their specific submodule has its `forward` method called.

    Args:
        gpu_id: Index of the CUDA device used as the execution device.

    Raises:
        ImportError: If `accelerate` is not installed.
    """
    if is_accelerate_available():
        from accelerate import cpu_offload
    else:
        raise ImportError("Please install accelerate via `pip install accelerate`")

    device = torch.device(f"cuda:{gpu_id}")

    models = [
        self.text_encoder,
        self.unet,
    ]
    for cpu_offloaded_model in models:
        # Optional components may be absent (e.g. `text_encoder=None`).
        if cpu_offloaded_model is not None:
            cpu_offload(cpu_offloaded_model, device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
def remove_all_hooks(self):
    """Strip accelerate hooks from the pipeline models and clear the stored offload hooks.

    Raises:
        ImportError: If `accelerate` is not installed.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate.hooks import remove_hook_from_module

    hooked_models = (self.text_encoder, self.unet, self.safety_checker)
    for hooked in hooked_models:
        if hooked is None:
            continue
        remove_hook_from_module(hooked, recurse=True)

    # Forget the manual offload handles as well.
    self.unet_offload_hook = self.text_encoder_offload_hook = self.final_offload_hook = None


@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed.

    When the unet carries no `_hf_hook` attribute this is `self.device`. Otherwise the unet's
    modules are scanned and the first non-`None` `_hf_hook.execution_device` found is returned
    as a `torch.device`; if none is found, `self.device` is returned.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`, *optional*): torch device to place the resulting embeddings on num_images_per_prompt (`int`, *optional*, defaults to 1): number of images that should be generated per prompt do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) if device is None: device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF max_length = 77 if prompt_embeds is None: prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {max_length} tokens: {removed_text}" ) attention_mask = text_inputs.attention_mask.to(device) prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] if self.text_encoder is not None: dtype = self.text_encoder.dtype elif self.unet is not None: dtype = self.unet.dtype else: dtype = None prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif 
isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) attention_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker over `image`.

    Args:
        image: Image batch; passed to `self.numpy_to_pil` and to the checker.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the checker's `clip_input` is cast to.

    Returns:
        A tuple `(image, nsfw_detected, watermark_detected)`. Without a safety
        checker the image is returned untouched and both flags are `None`.
    """
    if self.safety_checker is None:
        # With no checker to take over the device, offload the unet by hand.
        pending_hook = getattr(self, "unet_offload_hook", None)
        if pending_hook is not None:
            pending_hook.offload()
        return image, None, None

    pil_batch = self.numpy_to_pil(image)
    features = self.feature_extractor(pil_batch, return_tensors="pt").to(device)
    clip_input = features.pixel_values.to(dtype=dtype)
    image, nsfw_detected, watermark_detected = self.safety_checker(
        images=image,
        clip_input=clip_input,
    )
    return image, nsfw_detected, watermark_detected


# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Schedulers do not share one `step` signature, so `eta` and `generator`
    are forwarded only when the signature declares a parameter of that name.
    eta (η) corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should be between [0, 1].

    Returns:
        A dict holding `"eta"` and/or `"generator"`, in that order.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in step_params}
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # image if isinstance(image, list): check_image_type = image[0] else: check_image_type = image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" f" {type(check_image_type)}" ) if isinstance(image, list): image_batch_size = len(image) elif isinstance(image, torch.Tensor): image_batch_size = image.shape[0] elif isinstance(image, PIL.Image.Image): image_batch_size = 1 elif isinstance(image, np.ndarray): image_batch_size = image.shape[0] else: assert False if batch_size != image_batch_size: raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") # mask_image if isinstance(mask_image, list): check_image_type = mask_image[0] else: check_image_type = mask_image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`mask_image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
    """Normalize one prompt or a list/tuple of prompts before tokenization.

    Args:
        text: A single prompt or a list/tuple of prompts.
        clean_caption: When `True` (and both `bs4` and `ftfy` are available)
            each prompt is passed through `self._clean_caption`; otherwise
            prompts are only lower-cased and stripped.

    Returns:
        A list with one processed string per input prompt.
    """
    if clean_caption and not is_bs4_available():
        logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
        logger.warning("Setting `clean_caption` to False...")
        clean_caption = False

    if clean_caption and not is_ftfy_available():
        logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
        logger.warning("Setting `clean_caption` to False...")
        clean_caption = False

    if not isinstance(text, (tuple, list)):
        text = [text]

    def process(text: str):
        if clean_caption:
            # The cleaner is applied twice, exactly as before.
            # NOTE(review): presumably a second pass catches patterns exposed by the first - confirm.
            text = self._clean_caption(text)
            text = self._clean_caption(text)
        else:
            text = text.lower().strip()
        return text

    return [process(t) for t in text]


# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Scrub a caption of URLs, markup, ids, file names and other noise.

    The caption is URL-unquoted, lower-cased and then run through a fixed
    sequence of regex substitutions, HTML stripping (`BeautifulSoup`) and
    text repair (`ftfy.fix_text`). The order of the steps matters because
    later patterns operate on the output of earlier ones.

    Args:
        caption: The caption to clean; converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace removed.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the literal "<person>" placeholder tag. An empty pattern here
    # would insert "person" between every character.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # &quot; (HTML entity, with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # &amp (HTML entity prefix)
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # `str.strip` returns a new string; keep the result so the anchored
    # patterns below see the caption without leading/trailing blanks.
    caption = caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
def preprocess_mask_image(self, mask_image) -> torch.Tensor:
    """Convert a mask (or list of masks) into a binarized tensor.

    Accepted element types are `torch.Tensor`, `PIL.Image.Image` and
    `np.ndarray`; the type of the first element decides the branch. Values
    below 0.5 become 0 and values at or above 0.5 become 1.

    Args:
        mask_image: A single mask or a list of masks.

    Returns:
        The binarized mask as a tensor. If the first element matches none of
        the handled types, the input list is returned unchanged.
    """
    masks = mask_image if isinstance(mask_image, list) else [mask_image]
    head = masks[0]

    if isinstance(head, torch.Tensor):
        # 4-D inputs already carry a batch axis and are concatenated;
        # anything else is stacked along a new leading axis.
        combine = torch.cat if head.ndim == 4 else torch.stack
        batch = combine(masks, dim=0)

        if batch.ndim == 2:
            # Batch and add channel dim for single mask
            batch = batch.unsqueeze(0).unsqueeze(0)
        elif batch.ndim == 3:
            if batch.shape[0] == 1:
                # Single mask, the 0'th dimension is considered to be
                # the existing batch size of 1
                batch = batch.unsqueeze(0)
            else:
                # Batch of mask, the 0'th dimension is considered to be
                # the batching dimension
                batch = batch.unsqueeze(1)

        batch[batch < 0.5] = 0
        batch[batch >= 0.5] = 1
        return batch

    if isinstance(head, PIL.Image.Image):
        planes = []
        for pil_mask in masks:
            gray = resize(pil_mask.convert("L"), self.unet.sample_size)
            planes.append(np.array(gray)[None, None, :])

        batch = np.concatenate(planes, axis=0)
        batch = batch.astype(np.float32) / 255.0
        batch[batch < 0.5] = 0
        batch[batch >= 0.5] = 1
        return torch.from_numpy(batch)

    if isinstance(head, np.ndarray):
        batch = np.concatenate([arr[None, None, :] for arr in masks], axis=0)
        batch[batch < 0.5] = 0
        batch[batch >= 0.5] = 1
        return torch.from_numpy(batch)

    return masks
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of the scheduler's timesteps that `strength` asks for.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run; the resulting step count
            is capped at `num_inference_steps`.

    Returns:
        A tuple `(timesteps, steps)`: the slice of `self.scheduler.timesteps`
        to iterate over and the number of inference steps it represents.
    """
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    first_index = num_inference_steps - steps_to_run
    if first_index < 0:
        first_index = 0

    return self.scheduler.timesteps[first_index:], num_inference_steps - first_index


def prepare_intermediate_images(
    self, image, timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator=None
):
    """Build the starting sample: the input image with noise applied where the mask is set.

    Args:
        image: Image batch; its shape is unpacked into four dimensions.
        timestep: Timestep(s) passed to `self.scheduler.add_noise`.
        batch_size: Prompt batch size.
        num_images_per_prompt: Number of images generated per prompt.
        dtype: Dtype of the sampled noise.
        device: Device of the sampled noise.
        mask_image: Mask blended as `(1 - mask) * image + mask * noised_image`.
        generator: Optional generator or list of generators for the noise.

    Returns:
        The blended image batch.

    Raises:
        ValueError: If `generator` is a list whose length differs from
            `batch_size * num_images_per_prompt`.
    """
    _, channels, height, width = image.shape
    total = batch_size * num_images_per_prompt

    if isinstance(generator, list) and len(generator) != total:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {total}. Make sure the batch size matches the length of the generators."
        )

    noise = randn_tensor((total, channels, height, width), generator=generator, device=device, dtype=dtype)

    expanded = image.repeat_interleave(num_images_per_prompt, dim=0)
    noised = self.scheduler.add_noise(expanded, noise, timestep)

    # Unmasked pixels keep the clean image; masked pixels take the noised one.
    return (1 - mask_image) * expanded + mask_image * noised
negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, clean_caption: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. mask_image (`PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. 
guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] self.check_inputs( prompt, image, mask_image, batch_size, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. 
Define call parameters device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) dtype = prompt_embeds.dtype # 4. Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) # 5. Prepare intermediate images image = self.preprocess_image(image) image = image.to(device=device, dtype=dtype) mask_image = self.preprocess_mask_image(mask_image) mask_image = mask_image.to(device=device, dtype=dtype) if mask_image.shape[0] == 1: mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0) else: mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) noise_timestep = timesteps[0:1] noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) intermediate_images = self.prepare_intermediate_images( image, noise_timestep, batch_size, num_images_per_prompt, dtype, device, mask_image, generator ) # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = ( torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images ) model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 prev_intermediate_images = intermediate_images intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = 
intermediate_images if output_type == "pil": # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 10. Convert to PIL image = self.numpy_to_pil(image) # 11. Apply watermark if self.watermarker is not None: self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() else: # 8. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 9. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, nsfw_detected, watermark_detected) return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py ================================================ import html import inspect import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( BACKENDS_MAPPING, PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, is_bs4_available, is_ftfy_available, logging, randn_tensor, replace_example_docstring, ) from 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.resize
def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image:
    """Resize a PIL image so its shorter side becomes ``img_size``.

    The longer side is scaled by the original aspect ratio and rounded to a
    multiple of 8. Resampling uses the project's bicubic filter.

    Args:
        images: The PIL image to resize.
        img_size: Target length of the shorter side, in pixels.

    Returns:
        The resized PIL image.
    """
    src_w, src_h = images.size
    aspect = src_w / src_h

    if aspect >= 1:
        # Landscape or square: height is pinned, width follows the aspect ratio.
        target = (int(round(img_size / 8 * aspect) * 8), img_size)
    else:
        # Portrait: width is pinned, height follows the aspect ratio.
        target = (img_size, int(round(img_size / 8 / aspect) * 8))

    return images.resize(target, resample=PIL_INTERPOLATION["bicubic"], reducing_gap=None)
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate their combination.

    Args:
        tokenizer: Tokenizer used by `encode_prompt`.
        text_encoder: Text encoder used by `encode_prompt`.
        unet: Denoising UNet; a warning is logged unless it has 6 input channels.
        scheduler: Scheduler used for the denoising loop.
        image_noising_scheduler: Scheduler registered for noising the input image.
        safety_checker: Optional safety checker; requires `feature_extractor`.
        feature_extractor: Optional image processor feeding the safety checker.
        watermarker: Optional watermarker.
        requires_safety_checker: When `True`, a warning is logged if
            `safety_checker` is `None`. The value is stored in the config.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message was missing its `f` prefix, so `{self.__class__}` was printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # `_name_or_path` may be absent from the config, so read it defensively:
        # building the warning text must never raise.
        checkpoint_name = getattr(unet.config, "_name_or_path", None)
        # `logger.warn` is deprecated; the message also needed to be an f-string
        # for its placeholders to be filled in.
        logger.warning(
            f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {checkpoint_name} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.text_encoder, self.unet, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Normalize a raw caption before it is tokenized.

    The caption is URL-unquoted, lower-cased and then passed through a fixed
    sequence of substitutions that drop URLs, HTML markup, @-handles, CJK
    ranges, HTML entities, IP addresses, ids, file names, repeated
    punctuation and common boilerplate phrases. The order of the
    substitutions matters and is kept as in the original implementation.

    Requires `BeautifulSoup` and `ftfy` to be importable; the caller
    (`_text_preprocessing`) checks their availability before enabling cleaning.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        The cleaned, stripped caption string.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Fixed: the literal `<person>` tag had been lost from the pattern. An empty
    # pattern matches at every position and would insert "person" between all
    # characters of the caption.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @-handles
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # Fixed: the two HTML-entity patterns below had been entity-decoded, which
    # turned the first into a syntax error and made the second strip every "&".
    # quote entity (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # ampersand entity prefix
    caption = re.sub(r"&amp", "", caption)

    # ip addresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # literal backslash-n sequences
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    # runs of quotes / dots
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    # Only captions with more than three separators are treated as slugs.
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    # Unescape twice so doubly-escaped entities are resolved as well.
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    # dimensions such as 1920x1080 (latin x, cyrillic х or the × sign)
    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE(review): the result of this call is discarded, so it is a no-op.
    # It is deliberately left unchanged: assigning it would alter the output of
    # the anchored substitutions below and therefore the cleaned prompts.
    # Presumably this mirrors the reference implementation — confirm before changing.
    caption.strip()

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    do_classifier_free_guidance=True,
    num_images_per_prompt=1,
    device=None,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    clean_caption: bool = False,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on. Defaults to the pipeline's execution device.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. A single string is applied to every prompt in the batch. Ignored
            when `do_classifier_free_guidance` is `False`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        clean_caption (`bool`, defaults to `False`):
            Forwarded to `_text_preprocessing` for both the prompt and the negative prompt.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`. Each row is repeated `num_images_per_prompt` times.
        `negative_prompt_embeds` is `None` when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If a list `negative_prompt` does not match the batch size, or if text has to be encoded
            while the tokenizer or text encoder is not loaded.
    """
    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if device is None:
        device = self._execution_device

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
    max_length = 77

    if prompt_embeds is None:
        # Both components are optional on this pipeline; fail with a clear message
        # instead of "'NoneType' object is not callable".
        if self.tokenizer is None or self.text_encoder is None:
            raise ValueError(
                "Encoding `prompt` requires both `tokenizer` and `text_encoder` to be loaded. Either load them or"
                " pass `prompt_embeds` directly."
            )

        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            # Fixed: the message used to blame CLIP, but this pipeline encodes with T5.
            logger.warning(
                "The following part of your input was truncated because the T5 text encoder of IF only handles"
                f" sequences up to {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    # Pick the dtype of whichever model is available; `None` leaves the dtype untouched in `.to`.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.unet is not None:
        dtype = self.unet.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        if self.tokenizer is None or self.text_encoder is None:
            raise ValueError(
                "Encoding `negative_prompt` requires both `tokenizer` and `text_encoder` to be loaded. Either load"
                " them or pass `negative_prompt_embeds` directly."
            )

        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            # Fixed: broadcast a single negative prompt over the whole batch. With one
            # token list entry and `batch_size > 1` (possible when `prompt_embeds` is
            # passed), the `.view(batch_size * ...)` below could not succeed.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        attention_mask = uncond_input.attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # The unconditional and text embeddings are returned separately; the caller
        # may concatenate them into a single batch to run one forward pass.
    else:
        negative_prompt_embeds = None

    return prompt_embeds, negative_prompt_embeds
def check_inputs(
    self,
    prompt,
    image,
    original_image,
    mask_image,
    batch_size,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call.

    Args:
        prompt: `str`, `list` or `None`; exactly one of `prompt` / `prompt_embeds` must be given.
        image: Tensor, PIL image, ndarray, or a list of those. Its batch size must equal `batch_size`.
        original_image: Same accepted types and batch-size rule as `image`.
        mask_image: Same accepted types; its batch size must be `1` or `batch_size`.
        batch_size: Prompt batch size the images are compared against.
        callback_steps: Must be a positive `int`.
        negative_prompt: Must not be combined with `negative_prompt_embeds`.
        prompt_embeds: Must not be combined with `prompt`.
        negative_prompt_embeds: Must have the same shape as `prompt_embeds` when both are given.

    Raises:
        ValueError: If any of the rules above is violated, including an empty image list.
    """

    def _validated_batch_size(name, value):
        """Check the type of one image argument and return its batch size."""
        is_list = isinstance(value, list)
        if is_list and not value:
            # Previously surfaced as a bare IndexError from `value[0]`.
            raise ValueError(f"`{name}` must not be an empty list")

        probe = value[0] if is_list else value
        if (
            not isinstance(probe, torch.Tensor)
            and not isinstance(probe, PIL.Image.Image)
            and not isinstance(probe, np.ndarray)
        ):
            raise ValueError(
                f"`{name}` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
                f" {type(probe)}"
            )

        if is_list:
            return len(value)
        if isinstance(value, torch.Tensor) or isinstance(value, np.ndarray):
            return value.shape[0]
        # The type check above leaves a single PIL image as the only remaining case.
        return 1

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # image
    image_batch_size = _validated_batch_size("image", image)
    if batch_size != image_batch_size:
        raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")

    # original_image
    image_batch_size = _validated_batch_size("original_image", original_image)
    if batch_size != image_batch_size:
        raise ValueError(
            f"original_image batch size: {image_batch_size} must be same as prompt batch size {batch_size}"
        )

    # mask_image: a single mask may be shared by the whole batch
    image_batch_size = _validated_batch_size("mask_image", mask_image)
    if image_batch_size != 1 and batch_size != image_batch_size:
        raise ValueError(
            f"mask_image batch size: {image_batch_size} must be `1` or the same as prompt batch size {batch_size}"
        )
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_superresolution.IFSuperResolutionPipeline.preprocess_image
def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device) -> torch.Tensor:
    """Convert the input image(s) to a tensor on `device` and repeat per prompt.

    PIL images are converted to float32 arrays and scaled with `x / 127.5 - 1.0`.
    PIL and ndarray inputs are stacked and their last axis is moved to position 1
    (presumably NHWC -> NCHW). A list of tensors is stacked (3-D items) or
    concatenated (4-D items). A tensor input is used as given.

    Args:
        image: A PIL image, ndarray, tensor, or a list of those.
        num_images_per_prompt: Repeat count applied to every batch entry.
        device: Target device for the returned tensor.

    Returns:
        The tensor cast to the UNet's dtype, with each entry repeated
        `num_images_per_prompt` times along dim 0.

    Raises:
        ValueError: If a list of tensors holds items that are neither 3-D nor 4-D.
    """
    # Anything that is neither a tensor nor a list is treated as a single item.
    if not (isinstance(image, torch.Tensor) or isinstance(image, list)):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        scaled = []
        for pil_item in image:
            scaled.append(np.array(pil_item).astype(np.float32) / 127.5 - 1.0)

        batch = np.stack(scaled, axis=0)  # to np
        image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
    elif isinstance(image[0], np.ndarray):
        batch = np.stack(image, axis=0)  # to np
        # Stacking already-batched arrays adds a fifth axis; keep the first entry only.
        if batch.ndim == 5:
            batch = batch[0]

        image = torch.from_numpy(batch.transpose(0, 3, 1, 2))
    elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
        rank = image[0].ndim

        if rank == 3:
            image = torch.stack(image, dim=0)
        elif rank == 4:
            image = torch.concat(image, dim=0)
        else:
            raise ValueError(f"Image must have 3 or 4 dimensions, instead got {rank}")

    on_device = image.to(device=device, dtype=self.unet.dtype)
    return on_device.repeat_interleave(num_images_per_prompt, dim=0)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of the scheduler's timesteps that will actually be run.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Multiplier applied to `num_inference_steps` (truncated to
            an int) to decide how many steps to keep.

    Returns:
        A tuple `(timesteps, steps)` where `timesteps` is
        `self.scheduler.timesteps` sliced from the computed start index and
        `steps` is `num_inference_steps` minus that start index.
    """
    # Number of steps to keep, capped at the configured step count.
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    # Index of the first kept timestep; never negative.
    first_index = num_inference_steps - steps_to_run
    if first_index < 0:
        first_index = 0

    remaining = self.scheduler.timesteps[first_index:]
    return remaining, num_inference_steps - first_index
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor],
    original_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    mask_image: Union[
        PIL.Image.Image, torch.Tensor, np.ndarray, List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]
    ] = None,
    strength: float = 0.8,
    prompt: Union[str, List[str]] = None,
    num_inference_steps: int = 100,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 0,
    clean_caption: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        image (`torch.FloatTensor` or `PIL.Image.Image`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        original_image (`torch.FloatTensor` or `PIL.Image.Image`):
            The original image that `image` was varied from.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
            timesteps are used. Must be in descending order.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"pil"` returns PIL images; `"pt"` returns the final tensor
            as-is, without post-processing or the safety checker; any other value returns a NumPy array after
            post-processing and the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        noise_level (`int`, *optional*, defaults to 0):
            The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
        clean_caption (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
            be installed. If the dependencies are not installed, the embeddings will be created from the raw
            prompt.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple` of
        `(images, nsfw_detected, watermark_detected)`. The two detection entries are whatever
        `run_safety_checker` returns, or `None` when `output_type` is `"pt"`.
    """
    # 1. Check inputs. Raise error if not correct
    # The batch size is derived from the prompt when given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    self.check_inputs(
        prompt,
        image,
        original_image,
        mask_image,
        batch_size,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    device = self._execution_device

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt,
        do_classifier_free_guidance,
        num_images_per_prompt=num_images_per_prompt,
        device=device,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        clean_caption=clean_caption,
    )

    # With guidance, the negative and positive embeddings are batched together (negative first),
    # matching the `chunk(2)` split of the prediction in the denoising loop.
    if do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    dtype = prompt_embeds.dtype

    # 4. Prepare timesteps
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # Keep only the tail of the schedule selected by `strength`.
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)

    # 5. prepare original image
    original_image = self.preprocess_original_image(original_image)
    original_image = original_image.to(device=device, dtype=dtype)

    # 6. prepare mask image
    mask_image = self.preprocess_mask_image(mask_image)
    mask_image = mask_image.to(device=device, dtype=dtype)

    # A single mask is repeated for every generated image; a batch of masks is repeated per prompt.
    if mask_image.shape[0] == 1:
        mask_image = mask_image.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
    else:
        mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)

    # 7. Prepare intermediate images
    # The first remaining timestep is used as the noise timestep for every image in the batch.
    noise_timestep = timesteps[0:1]
    noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt)

    intermediate_images = self.prepare_intermediate_images(
        original_image,
        noise_timestep,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        mask_image,
        generator,
    )

    # 8. Prepare upscaled image and noise level
    _, _, height, width = original_image.shape

    image = self.preprocess_image(image, num_images_per_prompt, device)

    # Resize `image` to the spatial size of the original image.
    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    # From here on `noise_level` is a per-sample tensor rather than the scalar argument.
    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    # Duplicate the labels so they line up with the doubled model input used under guidance.
    if do_classifier_free_guidance:
        noise_level = torch.cat([noise_level] * 2)

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # HACK: see comment in `enable_model_cpu_offload`
    if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
        self.text_encoder_offload_hook.offload()

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # The model input is the current sample and the noised upscaled image joined along dim 1.
            model_input = torch.cat([intermediate_images, upscaled], dim=1)
            model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
            model_input = self.scheduler.scale_model_input(model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=noise_level,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # Each half is split along dim 1 into a noise part and a second part; only the
                # second part of the text-conditioned half is kept as `predicted_variance`.
                noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            # Drop the second part again when the scheduler is not configured for a learned variance.
            if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
                noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            prev_intermediate_images = intermediate_images

            intermediate_images = self.scheduler.step(
                noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
            )[0]

            # Where the mask is 0 the previous sample is kept; where it is 1 the new sample is used.
            intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, intermediate_images)

    image = intermediate_images

    if output_type == "pil":
        # 11. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 12. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # 13. Convert to PIL
        image = self.numpy_to_pil(image)

        # 14. Apply watermark
        if self.watermarker is not None:
            self.watermarker.apply_watermark(image, self.unet.config.sample_size)
    elif output_type == "pt":
        # The tensor is returned as-is: no post-processing and no safety checker.
        nsfw_detected = None
        watermark_detected = None

        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
            self.unet_offload_hook.offload()
    else:
        # 11. Post-processing
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # 12. Run safety checker
        image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, nsfw_detected, watermark_detected)

    return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected)
import IFPipelineOutput from .safety_checker import IFSafetyChecker from .watermark import IFWatermarker if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import IFPipeline, IFSuperResolutionPipeline, DiffusionPipeline >>> from diffusers.utils import pt_to_pil >>> import torch >>> pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16) >>> pipe.enable_model_cpu_offload() >>> prompt = 'a photo of a kangaroo wearing an orange hoodie and blue sunglasses standing in front of the eiffel tower holding a sign that says "very deep learning"' >>> prompt_embeds, negative_embeds = pipe.encode_prompt(prompt) >>> image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images >>> # save intermediate image >>> pil_image = pt_to_pil(image) >>> pil_image[0].save("./if_stage_I.png") >>> super_res_1_pipe = IFSuperResolutionPipeline.from_pretrained( ... "DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16", torch_dtype=torch.float16 ... ) >>> super_res_1_pipe.enable_model_cpu_offload() >>> image = super_res_1_pipe( ... image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds ... 
def __init__(
    self,
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    unet: UNet2DConditionModel,
    scheduler: DDPMScheduler,
    image_noising_scheduler: DDPMScheduler,
    safety_checker: Optional[IFSafetyChecker],
    feature_extractor: Optional[CLIPImageProcessor],
    watermarker: Optional[IFWatermarker],
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and validate their combination.

    Args:
        tokenizer: Tokenizer registered as `self.tokenizer`.
        text_encoder: Text encoder registered as `self.text_encoder`.
        unet: Denoising model; a warning is logged when
            `unet.config.in_channels` is not 6.
        scheduler: Scheduler registered as `self.scheduler`.
        image_noising_scheduler: Scheduler registered as
            `self.image_noising_scheduler`.
        safety_checker: Optional safety checker. Passing `None` while
            `requires_safety_checker` is True only logs a warning.
        feature_extractor: Required whenever `safety_checker` is given.
        watermarker: Optional watermarker registered as `self.watermarker`.
        requires_safety_checker: Stored in the pipeline config.

    Raises:
        ValueError: If `safety_checker` is provided without a
            `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the IF license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment needs the f-prefix, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    if unet.config.in_channels != 6:
        # f-string so the checkpoint name and channel count are interpolated;
        # `warning` replaces the deprecated `warn` alias.
        logger.warning(
            f"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
        )

    self.register_modules(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        unet=unet,
        scheduler=scheduler,
        image_noising_scheduler=image_noising_scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        watermarker=watermarker,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None if self.text_encoder is not None: _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook) # Accelerate will move the next model to the device _before_ calling the offload hook of the # previous model. This will cause both models to be present on the device at the same time. # IF uses T5 for its text encoder which is really large. We can manually call the offload # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to # the GPU. self.text_encoder_offload_hook = hook _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook) # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet self.unet_offload_hook = hook if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. 
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
    """Normalize a raw caption string before it is tokenized.

    The caption is URL-unquoted, lower-cased and then passed through a fixed
    sequence of substitutions that remove URLs, HTML markup, @-mentions, CJK
    ranges, HTML entities, IP addresses, ids, file names, repeated
    punctuation and a few boilerplate phrases, and that normalize dashes,
    quotes and whitespace. The order of the substitutions matters and is
    kept as in the original.

    Requires `BeautifulSoup` and `ftfy` to be importable.

    Args:
        caption: Any object; it is converted with `str()` first.

    Returns:
        The cleaned caption with surrounding whitespace stripped.
    """
    caption = str(caption)
    caption = ul.unquote_plus(caption)
    caption = caption.strip().lower()
    # Replace the literal "<person>" placeholder. An empty pattern here would
    # insert "person" at every position of the string.
    caption = re.sub("<person>", "person", caption)
    # urls:
    caption = re.sub(
        r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    caption = re.sub(
        r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
        "",
        caption,
    )  # regex for urls
    # html:
    caption = BeautifulSoup(caption, features="html.parser").text

    # @<nickname>
    caption = re.sub(r"@[\w\d]+\b", "", caption)

    # 31C0—31EF CJK Strokes
    # 31F0—31FF Katakana Phonetic Extensions
    # 3200—32FF Enclosed CJK Letters and Months
    # 3300—33FF CJK Compatibility
    # 3400—4DBF CJK Unified Ideographs Extension A
    # 4DC0—4DFF Yijing Hexagram Symbols
    # 4E00—9FFF CJK Unified Ideographs
    caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
    caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
    caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
    caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
    caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
    caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
    caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
    #######################################################

    # all types of dash --> "-"
    caption = re.sub(
        r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
        "-",
        caption,
    )

    # normalize quotation marks to one standard
    caption = re.sub(r"[`´«»“”¨]", '"', caption)
    caption = re.sub(r"[‘’]", "'", caption)

    # HTML entity for the double quote (with or without the trailing semicolon)
    caption = re.sub(r"&quot;?", "", caption)
    # HTML entity for the ampersand
    caption = re.sub(r"&amp", "", caption)

    # ip adresses:
    caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)

    # article ids:
    caption = re.sub(r"\d:\d\d\s+$", "", caption)

    # \n
    caption = re.sub(r"\\n", " ", caption)

    # "#123"
    caption = re.sub(r"#\d{1,3}\b", "", caption)
    # "#12345.."
    caption = re.sub(r"#\d{5,}\b", "", caption)
    # "123456.."
    caption = re.sub(r"\b\d{6,}\b", "", caption)
    # filenames:
    caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)

    #
    caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
    caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""

    caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
    caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "

    # this-is-my-cute-cat / this_is_my_cute_cat
    regex2 = re.compile(r"(?:\-|\_)")
    if len(re.findall(regex2, caption)) > 3:
        caption = re.sub(regex2, " ", caption)

    caption = ftfy.fix_text(caption)
    caption = html.unescape(html.unescape(caption))

    caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
    caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
    caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231

    caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
    caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
    caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
    caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
    caption = re.sub(r"\bpage\s+\d+\b", "", caption)

    caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...

    caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)

    caption = re.sub(r"\b\s+\:\s+", r": ", caption)
    caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
    caption = re.sub(r"\s+", " ", caption)

    # NOTE: the original had a bare `caption.strip()` here whose result was discarded.
    # It is omitted rather than assigned so the anchored patterns below see the same
    # input as before and the cleaned text does not change.

    caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
    caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
    caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
    caption = re.sub(r"^\.\S+$", "", caption)

    return caption.strip()
@torch.no_grad()
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
    self,
    prompt,
    do_classifier_free_guidance=True,
    num_images_per_prompt=1,
    device=None,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    clean_caption: bool = False,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            whether to use classifier free guidance or not
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            number of images that should be generated per prompt
        device: (`torch.device`, *optional*):
            torch device to place the resulting embeddings on. Defaults to `self._execution_device`.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when `do_classifier_free_guidance` is `False`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        clean_caption (`bool`, *optional*, defaults to `False`):
            Forwarded to `self._text_preprocessing` for both the prompt and the negative prompt.

    Returns:
        `tuple`: `(prompt_embeds, negative_prompt_embeds)`, each repeated `num_images_per_prompt` times.
        `negative_prompt_embeds` is `None` when `do_classifier_free_guidance` is `False`.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if device is None:
        device = self._execution_device

    # The batch size comes from the prompt when given, otherwise from the pre-computed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
    max_length = 77

    if prompt_embeds is None:
        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation, only to detect and report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
            # NOTE(review): the message names CLIP while the comment above refers to T5 — confirm wording.
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {max_length} tokens: {removed_text}"
            )

        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    # Pick the dtype from the text encoder, falling back to the unet; `None` leaves the dtype unchanged.
    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    elif self.unet is not None:
        dtype = self.unet.dtype
    else:
        dtype = None

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        attention_mask = uncond_input.attention_mask.to(device)

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
    else:
        negative_prompt_embeds = None

    return prompt_embeds, negative_prompt_embeds
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, batch_size, noise_level, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: raise ValueError( f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})" ) if isinstance(image, list): check_image_type = image[0] else: check_image_type = image if ( not isinstance(check_image_type, torch.Tensor) and not isinstance(check_image_type, PIL.Image.Image) and not isinstance(check_image_type, np.ndarray) ): raise ValueError( "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is" f" {type(check_image_type)}" ) if isinstance(image, list): image_batch_size = len(image) elif isinstance(image, torch.Tensor): image_batch_size = image.shape[0] elif isinstance(image, PIL.Image.Image): image_batch_size = 1 elif isinstance(image, np.ndarray): image_batch_size = image.shape[0] else: assert False if batch_size != image_batch_size: raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}") # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_intermediate_images def prepare_intermediate_images(self, batch_size, num_channels, height, width, dtype, device, generator): shape = (batch_size, num_channels, height, width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # scale the initial noise by the standard deviation required by the scheduler intermediate_images = intermediate_images * self.scheduler.init_noise_sigma return intermediate_images def preprocess_image(self, image, num_images_per_prompt, device): if not isinstance(image, torch.Tensor) and not isinstance(image, list): image = [image] if isinstance(image[0], PIL.Image.Image): image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] image = np.stack(image, axis=0) # to np image = torch.from_numpy(image.transpose(0, 3, 1, 2)) elif isinstance(image[0], np.ndarray): image = np.stack(image, axis=0) # to np if image.ndim == 5: image = image[0] image = torch.from_numpy(image.transpose(0, 3, 1, 2)) elif isinstance(image, list) and isinstance(image[0], torch.Tensor): dims = image[0].ndim if dims == 3: image = torch.stack(image, dim=0) elif dims == 4: image = torch.concat(image, dim=0) else: raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}") image = image.to(device=device, dtype=self.unet.dtype) image = image.repeat_interleave(num_images_per_prompt, dim=0) return image @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, height: int = None, width: int = None, image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, num_inference_steps: int = 50, timesteps: List[int] = None, guidance_scale: float = 4.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, 
cross_attention_kwargs: Optional[Dict[str, Any]] = None, noise_level: int = 250, clean_caption: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size): The width in pixels of the generated image. image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): The image to be upscaled. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` timesteps are used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to 250): The amount of noise to add to the upscaled image. 
Must be in the range `[0, 1000)` clean_caption (`bool`, *optional*, defaults to `True`): Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. Examples: Returns: [`~pipelines.stable_diffusion.IFPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.IFPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) or watermarked content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] self.check_inputs( prompt, image, batch_size, noise_level, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. Define call parameters height = height or self.unet.config.sample_size width = width or self.unet.config.sample_size device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt, device=device, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, clean_caption=clean_caption, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps if timesteps is not None: self.scheduler.set_timesteps(timesteps=timesteps, device=device) timesteps = self.scheduler.timesteps num_inference_steps = len(timesteps) else: self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare intermediate images num_channels = self.unet.config.in_channels // 2 intermediate_images = self.prepare_intermediate_images( batch_size * num_images_per_prompt, num_channels, height, width, prompt_embeds.dtype, device, generator, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare upscaled image and noise level image = self.preprocess_image(image, num_images_per_prompt, device) upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True) noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device) noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype) upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level) if do_classifier_free_guidance: noise_level = torch.cat([noise_level] * 2) # HACK: see comment in `enable_model_cpu_offload` if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: self.text_encoder_offload_hook.offload() # 8. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): model_input = torch.cat([intermediate_images, upscaled], dim=1) model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input model_input = self.scheduler.scale_model_input(model_input, t) # predict the noise residual noise_pred = self.unet( model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if self.scheduler.config.variance_type not in ["learned", "learned_range"]: noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) image = intermediate_images if output_type == "pil": # 9. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 10. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # 11. Convert to PIL image = self.numpy_to_pil(image) # 12. 
Apply watermark if self.watermarker is not None: self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: self.unet_offload_hook.offload() else: # 9. Post-processing image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() # 10. Run safety checker image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, nsfw_detected, watermark_detected) return IFPipelineOutput(images=image, nsfw_detected=nsfw_detected, watermark_detected=watermark_detected) ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/safety_checker.py ================================================ import numpy as np import torch import torch.nn as nn from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel from ...utils import logging logger = logging.get_logger(__name__) class IFSafetyChecker(PreTrainedModel): config_class = CLIPConfig _no_split_modules = ["CLIPEncoderLayer"] def __init__(self, config: CLIPConfig): super().__init__(config) self.vision_model = CLIPVisionModelWithProjection(config.vision_config) self.p_head = nn.Linear(config.vision_config.projection_dim, 1) self.w_head = nn.Linear(config.vision_config.projection_dim, 1) @torch.no_grad() def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5): image_embeds = self.vision_model(clip_input)[0] nsfw_detected = self.p_head(image_embeds) nsfw_detected = nsfw_detected.flatten() nsfw_detected = nsfw_detected > p_threshold nsfw_detected = nsfw_detected.tolist() if any(nsfw_detected): logger.warning( "Potential NSFW content was detected in one or 
more images. A black image will be returned instead." " Try again with a different prompt and/or seed." ) for idx, nsfw_detected_ in enumerate(nsfw_detected): if nsfw_detected_: images[idx] = np.zeros(images[idx].shape) watermark_detected = self.w_head(image_embeds) watermark_detected = watermark_detected.flatten() watermark_detected = watermark_detected > w_threshold watermark_detected = watermark_detected.tolist() if any(watermark_detected): logger.warning( "Potential watermarked content was detected in one or more images. A black image will be returned instead." " Try again with a different prompt and/or seed." ) for idx, watermark_detected_ in enumerate(watermark_detected): if watermark_detected_: images[idx] = np.zeros(images[idx].shape) return images, nsfw_detected, watermark_detected ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/timesteps.py ================================================ fast27_timesteps = [ 999, 800, 799, 600, 599, 500, 400, 399, 377, 355, 333, 311, 288, 266, 244, 222, 200, 199, 177, 155, 133, 111, 88, 66, 44, 22, 0, ] smart27_timesteps = [ 999, 976, 952, 928, 905, 882, 858, 857, 810, 762, 715, 714, 572, 429, 428, 286, 285, 238, 190, 143, 142, 118, 95, 71, 47, 24, 0, ] smart50_timesteps = [ 999, 988, 977, 966, 955, 944, 933, 922, 911, 900, 899, 879, 859, 840, 820, 800, 799, 766, 733, 700, 699, 650, 600, 599, 500, 499, 400, 399, 350, 300, 299, 266, 233, 200, 199, 179, 159, 140, 120, 100, 99, 88, 77, 66, 55, 44, 33, 22, 11, 0, ] smart100_timesteps = [ 999, 995, 992, 989, 985, 981, 978, 975, 971, 967, 964, 961, 957, 956, 951, 947, 942, 937, 933, 928, 923, 919, 914, 913, 908, 903, 897, 892, 887, 881, 876, 871, 870, 864, 858, 852, 846, 840, 834, 828, 827, 820, 813, 806, 799, 792, 785, 784, 777, 770, 763, 756, 749, 742, 741, 733, 724, 716, 707, 699, 698, 688, 677, 666, 656, 655, 645, 634, 623, 613, 612, 598, 584, 570, 569, 555, 541, 527, 526, 505, 484, 483, 462, 440, 439, 396, 395, 352, 351, 308, 
307, 264, 263, 220, 219, 176, 132, 88, 44, 0, ] smart185_timesteps = [ 999, 997, 995, 992, 990, 988, 986, 984, 981, 979, 977, 975, 972, 970, 968, 966, 964, 961, 959, 957, 956, 954, 951, 949, 946, 944, 941, 939, 936, 934, 931, 929, 926, 924, 921, 919, 916, 914, 913, 910, 907, 905, 902, 899, 896, 893, 891, 888, 885, 882, 879, 877, 874, 871, 870, 867, 864, 861, 858, 855, 852, 849, 846, 843, 840, 837, 834, 831, 828, 827, 824, 821, 817, 814, 811, 808, 804, 801, 798, 795, 791, 788, 785, 784, 780, 777, 774, 770, 766, 763, 760, 756, 752, 749, 746, 742, 741, 737, 733, 730, 726, 722, 718, 714, 710, 707, 703, 699, 698, 694, 690, 685, 681, 677, 673, 669, 664, 660, 656, 655, 650, 646, 641, 636, 632, 627, 622, 618, 613, 612, 607, 602, 596, 591, 586, 580, 575, 570, 569, 563, 557, 551, 545, 539, 533, 527, 526, 519, 512, 505, 498, 491, 484, 483, 474, 466, 457, 449, 440, 439, 428, 418, 407, 396, 395, 381, 366, 352, 351, 330, 308, 307, 286, 264, 263, 242, 220, 219, 176, 175, 132, 131, 88, 44, 0, ] super27_timesteps = [ 999, 991, 982, 974, 966, 958, 950, 941, 933, 925, 916, 908, 900, 899, 874, 850, 825, 800, 799, 700, 600, 500, 400, 300, 200, 100, 0, ] super40_timesteps = [ 999, 992, 985, 978, 971, 964, 957, 949, 942, 935, 928, 921, 914, 907, 900, 899, 879, 859, 840, 820, 800, 799, 766, 733, 700, 699, 650, 600, 599, 500, 499, 400, 399, 300, 299, 200, 199, 100, 99, 0, ] super100_timesteps = [ 999, 996, 992, 989, 985, 982, 979, 975, 972, 968, 965, 961, 958, 955, 951, 948, 944, 941, 938, 934, 931, 927, 924, 920, 917, 914, 910, 907, 903, 900, 899, 891, 884, 876, 869, 861, 853, 846, 838, 830, 823, 815, 808, 800, 799, 788, 777, 766, 755, 744, 733, 722, 711, 700, 699, 688, 677, 666, 655, 644, 633, 622, 611, 600, 599, 585, 571, 557, 542, 528, 514, 500, 499, 485, 471, 457, 442, 428, 414, 400, 399, 379, 359, 340, 320, 300, 299, 279, 259, 240, 220, 200, 199, 166, 133, 100, 99, 66, 33, 0, ] ================================================ FILE: 6DoF/diffusers/pipelines/deepfloyd_if/watermark.py 
================================================ from typing import List import PIL import torch from PIL import Image from ...configuration_utils import ConfigMixin from ...models.modeling_utils import ModelMixin from ...utils import PIL_INTERPOLATION class IFWatermarker(ModelMixin, ConfigMixin): def __init__(self): super().__init__() self.register_buffer("watermark_image", torch.zeros((62, 62, 4))) self.watermark_image_as_pil = None def apply_watermark(self, images: List[PIL.Image.Image], sample_size=None): # copied from https://github.com/deep-floyd/IF/blob/b77482e36ca2031cb94dbca1001fc1e6400bf4ab/deepfloyd_if/modules/base.py#L287 h = images[0].height w = images[0].width sample_size = sample_size or h coef = min(h / sample_size, w / sample_size) img_h, img_w = (int(h / coef), int(w / coef)) if coef < 1 else (h, w) S1, S2 = 1024**2, img_w * img_h K = (S2 / S1) ** 0.5 wm_size, wm_x, wm_y = int(K * 62), img_w - int(14 * K), img_h - int(14 * K) if self.watermark_image_as_pil is None: watermark_image = self.watermark_image.to(torch.uint8).cpu().numpy() watermark_image = Image.fromarray(watermark_image, mode="RGBA") self.watermark_image_as_pil = watermark_image wm_img = self.watermark_image_as_pil.resize( (wm_size, wm_size), PIL_INTERPOLATION["bicubic"], reducing_gap=None ) for pil_img in images: pil_img.paste(wm_img, box=(wm_x - wm_size, wm_y - wm_size, wm_x, wm_y), mask=wm_img.split()[-1]) return images ================================================ FILE: 6DoF/diffusers/pipelines/dit/__init__.py ================================================ from .pipeline_dit import DiTPipeline ================================================ FILE: 6DoF/diffusers/pipelines/dit/pipeline_dit.py ================================================ # Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) # William Peebles and Saining Xie # # Copyright (c) 2021 OpenAI # MIT License # # Copyright 2023 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, List, Optional, Tuple, Union import torch from ...models import AutoencoderKL, Transformer2DModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DiTPipeline(DiffusionPipeline): r""" This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: transformer ([`Transformer2DModel`]): Class conditioned Transformer in Diffusion model to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `dit` to denoise the encoded image latents. 
""" def __init__( self, transformer: Transformer2DModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, id2label: Optional[Dict[int, str]] = None, ): super().__init__() self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) # create a imagenet -> id dictionary for easier use self.labels = {} if id2label is not None: for key, value in id2label.items(): for label in value.split(","): self.labels[label.lstrip().rstrip()] = int(key) self.labels = dict(sorted(self.labels.items())) def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: r""" Map label strings, *e.g.* from ImageNet, to corresponding class ids. Parameters: label (`str` or `dict` of `str`): label strings to be mapped to class ids. Returns: `list` of `int`: Class ids to be processed by pipeline. """ if not isinstance(label, list): label = list(label) for l in label: if l not in self.labels: raise ValueError( f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." ) return [self.labels[l] for l in label] @torch.no_grad() def __call__( self, class_labels: List[int], guidance_scale: float = 4.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, num_inference_steps: int = 50, output_type: Optional[str] = "pil", return_dict: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. Args: class_labels (List[int]): List of imagenet class labels for the images to be generated. guidance_scale (`float`, *optional*, defaults to 4.0): Scale of the guidance signal. generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. num_inference_steps (`int`, *optional*, defaults to 250): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. """ batch_size = len(class_labels) latent_size = self.transformer.config.sample_size latent_channels = self.transformer.config.in_channels latents = randn_tensor( shape=(batch_size, latent_channels, latent_size, latent_size), generator=generator, device=self.device, dtype=self.transformer.dtype, ) latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents class_labels = torch.tensor(class_labels, device=self.device).reshape(-1) class_null = torch.tensor([1000] * batch_size, device=self.device) class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels # set step values self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): if guidance_scale > 1: half = latent_model_input[: len(latent_model_input) // 2] latent_model_input = torch.cat([half, half], dim=0) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) timesteps = t if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = latent_model_input.device.type == "mps" if isinstance(timesteps, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(latent_model_input.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(latent_model_input.shape[0]) # predict noise model_output noise_pred = self.transformer( latent_model_input, timestep=timesteps, class_labels=class_labels_input ).sample # perform guidance if guidance_scale > 1: eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) eps = torch.cat([half_eps, half_eps], dim=0) noise_pred = torch.cat([eps, rest], dim=1) # learned sigma if self.transformer.config.out_channels // 2 == latent_channels: model_output, _ = torch.split(noise_pred, latent_channels, dim=1) else: model_output = noise_pred # compute previous image: x_t -> x_t-1 latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample if guidance_scale > 1: latents, _ = latent_model_input.chunk(2, dim=0) else: latents = latent_model_input latents = 1 / self.vae.config.scaling_factor * latents samples = self.vae.decode(latents).sample samples = (samples / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": samples = self.numpy_to_pil(samples) if not return_dict: return (samples,) return ImagePipelineOutput(images=samples) ================================================ 
FILE: 6DoF/diffusers/pipelines/kandinsky/__init__.py
================================================
# Package entry point for the Kandinsky pipelines: the real pipeline classes are exported only when both
# `torch` and `transformers` are available.
from ...utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


try:
    # Use the exception purely as a signal to select the fallback import branch below.
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Presumably placeholder classes that fail with a helpful message on use -- confirm in the dummy module.
    # NOTE(review): only two of the names exported by the `else` branch get a fallback here.
    from ...utils.dummy_torch_and_transformers_objects import KandinskyPipeline, KandinskyPriorPipeline
else:
    from .pipeline_kandinsky import KandinskyPipeline
    from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline
    from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline
    from .pipeline_kandinsky_prior import KandinskyPriorPipeline, KandinskyPriorPipelineOutput
    from .text_encoder import MultilingualCLIP

================================================
FILE: 6DoF/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
================================================
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union import torch from transformers import ( XLMRobertaTokenizer, ) from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler, DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> negative_image_emb = out.negative_image_embeds >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") >>> pipe.to("cuda") >>> image = pipe( ... prompt, ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... ).images >>> image[0].save("cat.png") ``` """ def get_new_h_w(h, w, scale_factor=8): new_h = h // scale_factor**2 if h % scale_factor**2 != 0: new_h += 1 new_w = w // scale_factor**2 if w % scale_factor**2 != 0: new_w += 1 return new_h * scale_factor, new_w * scale_factor class KandinskyPipeline(DiffusionPipeline): """ Pipeline for text-to-image generation using Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: text_encoder ([`MultilingualCLIP`]): Frozen text-encoder. 
tokenizer ([`XLMRobertaTokenizer`]): Tokenizer of class scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. movq ([`VQModel`]): MoVQ Decoder to generate the image from the latents. """ def __init__( self, text_encoder: MultilingualCLIP, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, DDPMScheduler], movq: VQModel, ): super().__init__() self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, movq=movq, ) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", truncation=True, max_length=77, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was 
truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) text_mask = text_inputs.attention_mask.to(device) prompt_embeds, text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.text_encoder, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=prompt_embeds.dtype, device=device ) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor = self.scheduler.timesteps num_channels_latents = self.unet.config.in_channels height, width = get_new_h_w(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, latents, self.scheduler, ) for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = 
noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Optional, Union import numpy as np import PIL import torch from PIL import Image from transformers import ( XLMRobertaTokenizer, ) from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "A red cartoon frog, 4k" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/frog.png" ... ) >>> image = pipe( ... prompt, ... image=init_image, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... strength=0.2, ... 
).images >>> image[0].save("red_frog.png") ``` """ def get_new_h_w(h, w, scale_factor=8): new_h = h // scale_factor**2 if h % scale_factor**2 != 0: new_h += 1 new_w = w // scale_factor**2 if w % scale_factor**2 != 0: new_w += 1 return new_h * scale_factor, new_w * scale_factor def prepare_image(pil_image, w=512, h=512): pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) arr = np.array(pil_image.convert("RGB")) arr = arr.astype(np.float32) / 127.5 - 1 arr = np.transpose(arr, [2, 0, 1]) image = torch.from_numpy(arr).unsqueeze(0) return image class KandinskyImg2ImgPipeline(DiffusionPipeline): """ Pipeline for image-to-image generation using Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: text_encoder ([`MultilingualCLIP`]): Frozen text-encoder. tokenizer ([`XLMRobertaTokenizer`]): Tokenizer of class scheduler ([`DDIMScheduler`]): A scheduler to be used in combination with `unet` to generate image latents. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the image embedding. 
movq ([`VQModel`]): MoVQ image encoder and decoder """ def __init__( self, text_encoder: MultilingualCLIP, movq: VQModel, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, scheduler: DDIMScheduler, ): super().__init__() self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, movq=movq, ) self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma shape = latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.add_noise(latents, noise, latent_timestep) return latents def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = 
self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids.to(device) text_mask = text_inputs.attention_mask.to(device) prompt_embeds, text_encoder_hidden_states = self.text_encoder( input_ids=text_input_ids, attention_mask=text_mask ) prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=77, truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) uncond_text_input_ids = uncond_input.input_ids.to(device) uncond_text_mask = uncond_input.attention_mask.to(device) negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.text_encoder, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # add_noise method to overwrite the one in schedule because it use a different beta schedule for adding noise vs sampling def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32) alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], image_embeds: torch.FloatTensor, negative_image_embeds: torch.FloatTensor, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, strength: float = 0.3, guidance_scale: float = 7.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = 
None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`torch.FloatTensor`, `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. strength (`float`, *optional*, defaults to 0.3): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. 
of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ # 1. Define call parameters if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 # 2. 
get text and image embeddings prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( dtype=prompt_embeds.dtype, device=device ) # 3. pre-processing initial image if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) image = image.to(dtype=prompt_embeds.dtype, device=device) latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # the formular to calculate timestep for add_noise is taken from the original kandinsky repo latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) num_channels_latents = self.unet.config.in_channels height, width = get_new_h_w(height, width, self.movq_scale_factor) # 5. 
Create initial latent latents = self.prepare_latents( latents, latent_timestep, (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, self.scheduler, ) # 6. Denoising loop for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, ).prev_sample # 7. post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy from typing import List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from PIL import Image from transformers import ( XLMRobertaTokenizer, ) from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDIMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from .text_encoder import MultilingualCLIP logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> import numpy as np >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "a hat" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyInpaintPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> mask = np.ones((768, 768), dtype=np.float32) >>> mask[:250, 250:-250] = 0 >>> out = pipe( ... prompt, ... image=init_image, ... mask_image=mask, ... image_embeds=image_emb, ... 
def get_new_h_w(h, w, scale_factor=8):
    """Return the latent-grid height and width for a requested pixel size.

    Each dimension is divided by ``scale_factor**2`` rounding up, and the result is multiplied by
    ``scale_factor``.

    Args:
        h: Requested height.
        w: Requested width.
        scale_factor: Scale factor of the autoencoder (defaults to 8).

    Returns:
        A ``(new_h, new_w)`` tuple.
    """
    block = scale_factor**2

    def _ceil_div(value):
        # Ceiling division without going through floats.
        return -(-value // block)

    return _ceil_div(h) * scale_factor, _ceil_div(w) * scale_factor


def prepare_mask(masks):
    """Grow the non-one region of every mask by one pixel and return the result as a new tensor.

    A pixel is treated as a "hole" when channel 0 of the mask is anything other than exactly ``1``. For every
    hole at ``(i, j)`` the neighbours ``(i-1, j)``, ``(i, j-1)``, ``(i-1, j-1)``, ``(i+1, j)``, ``(i, j+1)`` and
    ``(i+1, j+1)`` are set to ``0`` in all channels (the two anti-diagonal neighbours and the hole pixel itself
    are left untouched). Holes are detected on the input values only, so the growth is exactly one step.

    The computation is vectorised; unlike a per-pixel Python loop it does not modify ``masks`` in place.

    Args:
        masks: A ``batch x channels x height x width`` tensor, or an iterable of
            ``channels x height x width`` tensors.

    Returns:
        torch.Tensor: A new ``batch x channels x height x width`` tensor with the grown holes.
    """
    if not isinstance(masks, torch.Tensor):
        masks = torch.stack(list(masks), dim=0)

    hole = masks[:, 0] != 1  # batch x height x width
    grown = torch.zeros_like(hole)

    # Each line marks one neighbour offset of every hole; the slicing drops neighbours that would fall
    # outside the image, which mirrors the boundary checks of a per-pixel implementation.
    grown[:, :-1, :] |= hole[:, 1:, :]  # (i - 1, j)
    grown[:, :, :-1] |= hole[:, :, 1:]  # (i, j - 1)
    grown[:, :-1, :-1] |= hole[:, 1:, 1:]  # (i - 1, j - 1)
    grown[:, 1:, :] |= hole[:, :-1, :]  # (i + 1, j)
    grown[:, :, 1:] |= hole[:, :, :-1]  # (i, j + 1)
    grown[:, 1:, 1:] |= hole[:, :-1, :-1]  # (i + 1, j + 1)

    # Broadcast the 2D marker over the channel dimension; masked_fill returns a fresh tensor.
    return masks.masked_fill(grown.unsqueeze(1), 0)


def prepare_mask_and_masked_image(image, mask, height, width):
    r"""
    Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
    binarized (``mask >= 0.5`` becomes ``1``, everything else ``0``). The caller's ``mask`` object is never modified.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
        height (`int`): The height in pixels that PIL inputs are resized to.
        width (`int`): The width in pixels that PIL inputs are resized to.

    Raises:
        ValueError: ``image`` or ``mask`` is ``None``.
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        AssertionError: ``mask`` and ``image`` tensors do not have matching batch / spatial dimensions.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4
            dimensions: ``batch x channels x height x width``.
    """
    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask. `unsqueeze` only creates views, so assigning through boolean indexing here would
        # overwrite the tensor the caller passed in; build a new tensor instead.
        mask = (mask >= 0.5).to(mask.dtype)

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]

        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t passed height and width
            image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # batch x H x W x C -> batch x C x H x W, then map [0, 255] to [-1, 1]
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `np.concatenate` already produced a private copy, so in-place binarization is safe here.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    return mask, image
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
):
    """Tokenize and encode `prompt` (and, under guidance, the negative prompt) with the text encoder.

    Args:
        prompt: A prompt string or a list of prompt strings.
        device: Device the token ids and attention masks are moved to before encoding.
        num_images_per_prompt: How many times every embedding row is repeated.
        do_classifier_free_guidance: When true, unconditional embeddings are computed as well and placed
            *before* the conditional ones along the batch dimension.
        negative_prompt: Optional negative prompt(s); must have the same type as `prompt`, and the same
            length when it is a list. An empty string per prompt is used when omitted.

    Returns:
        A tuple `(prompt_embeds, text_encoder_hidden_states, text_mask)`.

    Raises:
        TypeError: `negative_prompt` and `prompt` have different types.
        ValueError: `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    if isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = 1

    # The prompt and the negative prompt are tokenized with identical settings.
    tokenize_kwargs = {
        "padding": "max_length",
        "max_length": 77,
        "truncation": True,
        "return_attention_mask": True,
        "add_special_tokens": True,
        "return_tensors": "pt",
    }

    tokenized = self.tokenizer(prompt, **tokenize_kwargs)
    text_input_ids = tokenized.input_ids
    longest_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    # Warn when the un-truncated tokenization differs from what is actually fed to the encoder.
    was_truncated = longest_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
        text_input_ids, longest_ids
    )
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(longest_ids[:, self.tokenizer.model_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        )

    text_input_ids = text_input_ids.to(device)
    text_mask = tokenized.attention_mask.to(device)

    prompt_embeds, text_encoder_hidden_states = self.text_encoder(
        input_ids=text_input_ids, attention_mask=text_mask
    )

    # One copy of every row per requested image.
    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

    if not do_classifier_free_guidance:
        return prompt_embeds, text_encoder_hidden_states, text_mask

    if negative_prompt is None:
        uncond_tokens = [""] * batch_size
    elif type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    elif isinstance(negative_prompt, str):
        uncond_tokens = [negative_prompt]
    elif batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )
    else:
        uncond_tokens = negative_prompt

    uncond_tokenized = self.tokenizer(uncond_tokens, **tokenize_kwargs)
    uncond_text_input_ids = uncond_tokenized.input_ids.to(device)
    uncond_text_mask = uncond_tokenized.attention_mask.to(device)

    negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder(
        input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask
    )

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    total_rows = batch_size * num_images_per_prompt

    embed_width = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt).view(total_rows, embed_width)

    hidden_seq_len = uncond_text_encoder_hidden_states.shape[1]
    uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(
        1, num_images_per_prompt, 1
    ).view(total_rows, hidden_seq_len, -1)

    uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

    # For classifier free guidance the unconditional and conditional inputs are stacked into a single
    # batch (unconditional first) so that one forward pass serves both.
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
    text_mask = torch.cat([uncond_text_mask, text_mask])

    return prompt_embeds, text_encoder_hidden_states, text_mask
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
    to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
    method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
    `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device the models are executed on.

    Raises:
        ImportError: `accelerate` is missing or older than 0.17.0.
    """
    if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
        from accelerate import cpu_offload_with_hook
    else:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # Chain the hooks so that running one model offloads the previously executed one.
    hook = None
    for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]:
        _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

    # This pipeline's `__init__` registers no `safety_checker`, so a plain attribute access can fail;
    # look it up defensively and only offload it when a subclass actually provides one.
    safety_checker = getattr(self, "safety_checker", None)
    if safety_checker is not None:
        _, hook = cpu_offload_with_hook(safety_checker, device, prev_module_hook=hook)

    # We'll offload the last model manually.
    self.final_offload_hook = hook
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[torch.FloatTensor, PIL.Image.Image],
    mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
    image_embeds: torch.FloatTensor,
    negative_image_embeds: torch.FloatTensor,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 512,
    width: int = 512,
    num_inference_steps: int = 100,
    guidance_scale: float = 4.0,
    num_images_per_prompt: int = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image (`torch.FloatTensor`, `PIL.Image.Image` or `np.ndarray`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process.
        mask_image (`PIL.Image.Image`,`torch.FloatTensor` or `np.ndarray`):
            `Image`, or a tensor representing an image batch, to mask `image`. You can pass a pytorch tensor as
            mask only if the image you passed is a pytorch tensor, and it should contain one color channel (L)
            instead of 3, so the expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H,
            W)`. If image is a PIL image or numpy array, mask should also be either a PIL image or a numpy
            array. If it is a PIL image, it will be converted to a single channel (luminance) before use. If it
            is a numpy array, the expected shape is `(H, W)`.
            NOTE(review): this method computes `masked_image = image * mask_image`, i.e. latent positions
            where the binarized mask is 1 are kept and positions where it is 0 are zeroed, and the usage
            example sets the region to be painted to 0. This suggests that 0 (black) marks the area to
            repaint - confirm against the model card before relying on either convention.
        image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
            The clip image embeddings for text prompt, that will be used to condition the image generation.
        negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
            The clip image embeddings for negative text prompt, will be used to condition the image generation.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
            ignored if `guidance_scale` is not greater than `1`).
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 100):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
            (`np.array`) or `"pt"` (`torch.Tensor`).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`
    """

    # Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    device = self._execution_device

    # From here on `batch_size` counts generated images, not prompts.
    batch_size = batch_size * num_images_per_prompt
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    if isinstance(image_embeds, list):
        image_embeds = torch.cat(image_embeds, dim=0)
    if isinstance(negative_image_embeds, list):
        negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

    if do_classifier_free_guidance:
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

        # Negative embeddings go first, matching the (unconditional, conditional) order of `prompt_embeds`.
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
            dtype=prompt_embeds.dtype, device=device
        )

    # preprocess image and mask
    mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)

    image = image.to(dtype=prompt_embeds.dtype, device=device)
    # `image` now holds the MoVQ latents of the input image rather than pixels.
    image = self.movq.encode(image)["latents"]

    mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device)

    # Bring the mask down to the spatial size of the latents.
    image_shape = tuple(image.shape[-2:])
    mask_image = F.interpolate(
        mask_image,
        image_shape,
        mode="nearest",
    )
    mask_image = prepare_mask(mask_image)
    # Latent positions where the mask is 0 are zeroed out.
    masked_image = image * mask_image

    mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
    masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
    if do_classifier_free_guidance:
        # Duplicate for the unconditional and conditional halves of the batch.
        mask_image = mask_image.repeat(2, 1, 1, 1)
        masked_image = masked_image.repeat(2, 1, 1, 1)

    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps_tensor = self.scheduler.timesteps

    num_channels_latents = self.movq.config.latent_channels

    # get h, w for latents
    sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor)

    # create initial latent
    latents = self.prepare_latents(
        (batch_size, num_channels_latents, sample_height, sample_width),
        text_encoder_hidden_states.dtype,
        device,
        generator,
        latents,
        self.scheduler,
    )

    # Check that sizes of mask, masked image and latents match with expected
    num_channels_mask = mask_image.shape[1]
    num_channels_masked_image = masked_image.shape[1]
    if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
        raise ValueError(
            f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
            f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
            f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
            " `pipeline.unet` or your `mask_image` or `image` input."
        )

    for i, t in enumerate(self.progress_bar(timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        # The UNet input is the channel-wise concatenation of noisy latents, masked image latents and mask.
        latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)

        added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
        noise_pred = self.unet(
            sample=latent_model_input,
            timestep=t,
            encoder_hidden_states=text_encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]

        if do_classifier_free_guidance:
            # The UNet output is split along channels into a noise part (first `latents.shape[1]`
            # channels) and a remainder treated as the predicted variance.
            noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            _, variance_pred_text = variance_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

        # Drop the variance channels unless the scheduler is configured to consume a learned variance.
        if not (
            hasattr(self.scheduler.config, "variance_type")
            and self.scheduler.config.variance_type in ["learned", "learned_range"]
        ):
            noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(
            noise_pred,
            t,
            latents,
            generator=generator,
        ).prev_sample

    # post-processing
    image = self.movq.decode(latents, force_not_quantize=True)["sample"]

    # NOTE(review): `output_type` is only validated here, after the full denoising loop has run.
    if output_type not in ["pt", "np", "pil"]:
        raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

    if output_type in ["np", "pil"]:
        # Map from [-1, 1] to [0, 1] and move channels last for numpy / PIL consumers.
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( BaseOutput, is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> negative_image_emb = out.negative_image_embeds >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") >>> pipe.to("cuda") >>> image = pipe( ... prompt, ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... 
).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) >>> pipe.to("cuda") >>> image = pipe( ... "", ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=150, ... ).images[0] >>> image.save("starry_cat.png") ``` """ @dataclass class KandinskyPriorPipelineOutput(BaseOutput): """ Output class for KandinskyPriorPipeline. Args: image_embeds (`torch.FloatTensor`) clip image embeddings for text prompt negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) clip image embeddings for unconditional tokens """ image_embeds: Union[torch.FloatTensor, np.ndarray] negative_image_embeds: Union[torch.FloatTensor, np.ndarray] class KandinskyPriorPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
def interpolate(
    self,
    images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
    weights: List[float],
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    negative_prior_prompt: Optional[str] = None,
    negative_prompt: Union[str] = "",
    guidance_scale: float = 4.0,
    device=None,
):
    """
    Function invoked when using the prior pipeline for interpolation.

    Every condition is turned into an image embedding (text through the prior itself, images through the
    image encoder), scaled by its weight, and the weighted embeddings are summed into a single embedding.

    Args:
        images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
            list of prompts and images to guide the image generation.
        weights: (`List[float]`):
            list of weights for each condition in `images_and_prompts`
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        negative_prior_prompt (`str`, *optional*):
            The prompt not to guide the prior diffusion process. Forwarded as `negative_prompt` to every call
            of the prior.
        negative_prompt (`str`, *optional*, defaults to `""`):
            The prompt used to produce the negative image embedding. When it is the empty string, the
            `negative_image_embeds` of the prior output are used; otherwise its `image_embeds` are used.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        device (*optional*):
            Device PIL inputs are moved to before encoding; falls back to the pipeline device.

    Examples:

    Returns:
        [`KandinskyPriorPipelineOutput`]
    """
    device = device if device else self.device

    if len(images_and_prompts) != len(weights):
        raise ValueError(
            f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
        )

    # Settings shared by every invocation of the prior in this method.
    prior_kwargs = {
        "num_inference_steps": num_inference_steps,
        "num_images_per_prompt": num_images_per_prompt,
        "generator": generator,
        "latents": latents,
        "negative_prompt": negative_prior_prompt,
        "guidance_scale": guidance_scale,
    }

    weighted_embeds = []
    for cond, weight in zip(images_and_prompts, weights):
        if isinstance(cond, str):
            # Text conditions are mapped to an image embedding by running the prior.
            embed = self(cond, **prior_kwargs).image_embeds
        elif isinstance(cond, PIL.Image.Image):
            pixels = self.image_processor(cond, return_tensors="pt").pixel_values[0].unsqueeze(0)
            pixels = pixels.to(dtype=self.image_encoder.dtype, device=device)
            embed = self.image_encoder(pixels)["image_embeds"]
        elif isinstance(cond, torch.Tensor):
            # Tensors are passed to the image encoder as given.
            embed = self.image_encoder(cond)["image_embeds"]
        else:
            raise ValueError(
                f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor`  but is {type(cond)}"
            )

        weighted_embeds.append(embed * weight)

    image_emb = torch.cat(weighted_embeds).sum(dim=0, keepdim=True)

    out_zero = self(negative_prompt, **prior_kwargs)
    if negative_prompt == "":
        zero_image_emb = out_zero.negative_image_embeds
    else:
        zero_image_emb = out_zero.image_embeds

    return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
    """
    Return the initial latents scaled by `scheduler.init_noise_sigma`.

    Fresh noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the supplied tensor is moved to
    `device`.

    Raises:
        ValueError: If supplied `latents` do not have exactly `shape`.
    """
    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        latents = latents.to(device)

    latents = latents * scheduler.init_noise_sigma
    return latents


def get_zero_embed(self, batch_size=1, device=None):
    """Encode an all-zero image with the image encoder and repeat the embedding `batch_size` times."""
    device = device or self.device
    zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to(
        device=device, dtype=self.image_encoder.dtype
    )
    zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
    zero_image_emb = zero_image_emb.repeat(batch_size, 1)
    return zero_image_emb


def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
    models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
    when their specific submodule has its `forward` method called.

    Only `image_encoder` and `text_encoder` are handed to `cpu_offload` here.

    Raises:
        ImportError: If `accelerate` is not available.
    """
    if is_accelerate_available():
        from accelerate import cpu_offload
    else:
        raise ImportError("Please install accelerate via `pip install accelerate`")

    device = torch.device(f"cuda:{gpu_id}")

    models = [
        self.image_encoder,
        self.text_encoder,
    ]
    for cpu_offloaded_model in models:
        if cpu_offloaded_model is not None:
            cpu_offload(cpu_offloaded_model, device)


@property
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
        return self.device
    # The first text-encoder sub-module whose hook names an execution device decides.
    for module in self.text_encoder.modules():
        if (
            hasattr(module, "_hf_hook")
            and hasattr(module._hf_hook, "execution_device")
            and module._hf_hook.execution_device is not None
        ):
            return torch.device(module._hf_hook.execution_device)
    return self.device


def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
):
    """
    Tokenize and encode `prompt` (and, under classifier-free guidance, the negative prompt).

    Returns:
        `(prompt_embeds, text_encoder_hidden_states, text_mask)`. Each is repeated `num_images_per_prompt` times
        per prompt. With `do_classifier_free_guidance` the unconditional entries are concatenated in front of the
        conditional ones along the batch axis.

    Raises:
        TypeError: If `negative_prompt` is given with a different type than `prompt`.
        ValueError: If a list `negative_prompt` has a different length than `prompt`.
    """
    batch_size = len(prompt) if isinstance(prompt, list) else 1
    # get prompt text embeddings
    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    text_mask = text_inputs.attention_mask.bool().to(device)

    # Tokenize again without truncation only to detect, and warn about, text that was cut off.
    untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
        )
        text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

    text_encoder_output = self.text_encoder(text_input_ids.to(device))

    prompt_embeds = text_encoder_output.text_embeds
    text_encoder_hidden_states = text_encoder_output.last_hidden_state

    prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
    text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_text_mask = uncond_input.attention_mask.bool().to(device)
        negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

        negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
        uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

        # The second axis of the pooled embedding is its width, not a sequence length.
        embed_dim = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, embed_dim)

        seq_len = uncond_text_encoder_hidden_states.shape[1]
        uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
        uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )
        uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        # done duplicates

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

        text_mask = torch.cat([uncond_text_mask, text_mask])

    return prompt_embeds, text_encoder_hidden_states, text_mask


@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    guidance_scale: float = 4.0,
    output_type: Optional[str] = "pt",  # pt only
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. When given, it is appended to `prompt` so
            that its image embedding is produced in the same run and returned as `negative_image_embeds`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        output_type (`str`, *optional*, defaults to `"pt"`):
            The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
            (`torch.Tensor`).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.

    Examples:

    Returns:
        [`KandinskyPriorPipelineOutput`] or `tuple`

    Raises:
        ValueError: If `prompt` / `negative_prompt` is neither a `str` nor a `list`, or if `output_type` is not
            `"pt"` or `"np"`.
    """

    if isinstance(prompt, str):
        prompt = [prompt]
    elif not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    elif not isinstance(negative_prompt, list) and negative_prompt is not None:
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # Fail on an unsupported output format before running the denoising loop rather than after it.
    if output_type not in ["pt", "np"]:
        raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

    # if the negative prompt is defined we double the batch size to
    # directly retrieve the negative prompt embedding
    if negative_prompt is not None:
        prompt = prompt + negative_prompt
        negative_prompt = 2 * negative_prompt

    device = self._execution_device

    batch_size = len(prompt)
    batch_size = batch_size * num_images_per_prompt

    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # prior
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    prior_timesteps_tensor = self.scheduler.timesteps

    embedding_dim = self.prior.config.embedding_dim

    latents = self.prepare_latents(
        (batch_size, embedding_dim),
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        self.scheduler,
    )

    for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        predicted_image_embedding = self.prior(
            latent_model_input,
            timestep=t,
            proj_embedding=prompt_embeds,
            encoder_hidden_states=text_encoder_hidden_states,
            attention_mask=text_mask,
        ).predicted_image_embedding

        if do_classifier_free_guidance:
            predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
            predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                predicted_image_embedding_text - predicted_image_embedding_uncond
            )

        # The scheduler is told the next timestep explicitly; there is none after the last step.
        if i + 1 == prior_timesteps_tensor.shape[0]:
            prev_timestep = None
        else:
            prev_timestep = prior_timesteps_tensor[i + 1]

        latents = self.scheduler.step(
            predicted_image_embedding,
            timestep=t,
            sample=latents,
            generator=generator,
            prev_timestep=prev_timestep,
        ).prev_sample

    latents = self.prior.post_process_latents(latents)

    image_embeddings = latents

    # if negative prompt has been defined, we split the image embedding into two
    if negative_prompt is None:
        zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
    else:
        image_embeddings, zero_embeds = image_embeddings.chunk(2)

    if output_type == "np":
        image_embeddings = image_embeddings.cpu().numpy()
        zero_embeds = zero_embeds.cpu().numpy()

    if not return_dict:
        return (image_embeddings, zero_embeds)

    return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
class MCLIPConfig(XLMRobertaConfig):
    """XLM-RoBERTa configuration extended with the two sizes of the projection layer."""

    model_type = "M-CLIP"

    def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs):
        """
        Args:
            transformerDimSize: stored as `transformerDimensions` (input width of the projection).
            imageDimSize: stored as `numDims` (output width of the projection).
            **kwargs: forwarded unchanged to `XLMRobertaConfig`.
        """
        # Both fields are assigned before delegating to the parent constructor.
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)


class MultilingualCLIP(PreTrainedModel):
    """XLM-RoBERTa encoder followed by a linear projection of its mask-averaged token states."""

    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        """Build the transformer and the `transformerDimensions -> numDims` linear layer from `config`."""
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions,
            out_features=config.numDims,
        )

    def forward(self, input_ids, attention_mask):
        """
        Encode `input_ids` and pool the token states over the positions selected by `attention_mask`.

        Returns:
            A pair `(projected, token_states)`: the linear projection of the masked mean, and the first output
            of the transformer.
        """
        token_states = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]

        # Masked mean over the token axis: masked positions contribute nothing to the sum, and the
        # divisor is the mask total per row.
        masked_total = (token_states * attention_mask.unsqueeze(2)).sum(dim=1)
        mask_total = attention_mask.sum(dim=1)[:, None]
        pooled = masked_total / mask_total

        return self.LinearTransformation(pooled), token_states
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Union import torch from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> out = pipe_prior(prompt) >>> image_emb = out.image_embeds >>> zero_image_emb = out.negative_image_embeds >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... 
def downscale_height_and_width(height, width, scale_factor=8):
    """
    Convert pixel dimensions into latent dimensions.

    Each dimension is divided by `scale_factor**2`, rounded up, and multiplied by `scale_factor`.

    Returns:
        The tuple `(latent_height, latent_width)`.
    """
    new_height = height // scale_factor**2
    if height % scale_factor**2 != 0:
        new_height += 1
    new_width = width // scale_factor**2
    if width % scale_factor**2 != 0:
        new_width += 1
    return new_height * scale_factor, new_width * scale_factor


class KandinskyV22Pipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler (Union[`DDIMScheduler`,`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # Spatial reduction factor derived from the number of MoVQ blocks in its config.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
        when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        # Each hook is chained to the previous one through `prev_module_hook`.
        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only used when `guidance_scale > 1`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`

        Raises:
            ValueError: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`.
        """
        # Fail on an unsupported output format before denoising and decoding rather than after.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        batch_size = image_embeds.shape[0] * num_images_per_prompt
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        # The latents below are allocated for `batch_size` samples, so the conditioning must be repeated per
        # image whether or not guidance is enabled; otherwise the UNet receives mismatched batch sizes.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Unconditional half first; the `chunk(2)` calls in the loop rely on this order.
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels

        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The UNet output is split channel-wise into a noise part and a variance part; guidance is
                # applied to the noise part and the variance of the conditional half is re-attached.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            # Schedulers that do not use a learned variance only receive the noise channels.
            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type in ["np", "pil"]:
            # Rescale with `x * 0.5 + 0.5`, clamp to [0, 1] and move channels last for numpy / PIL.
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
from typing import List, Optional, Union import torch from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import numpy as np >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22ControlnetPipeline >>> from transformers import pipeline >>> from diffusers.utils import load_image >>> def make_hint(image, depth_estimator): ... image = depth_estimator(image)["depth"] ... image = np.array(image) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... detected_map = torch.from_numpy(image).float() / 255.0 ... hint = detected_map.permute(2, 0, 1) ... return hint >>> depth_estimator = pipeline("depth-estimation") >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior = pipe_prior.to("cuda") >>> pipe = KandinskyV22ControlnetPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... 
class KandinskyV22ControlnetPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using Kandinsky, conditioned on an additional controlnet `hint`.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # Derived from the number of entries in the MoVQ `block_out_channels` config; it is the factor handed to
        # `downscale_height_and_width` when sizing the latents.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        """
        Return the initial latents, scaled by `scheduler.init_noise_sigma`.

        When `latents` is `None`, fresh noise of `shape` is drawn with `randn_tensor`; otherwise the supplied
        tensor is moved to `device` after its shape has been checked.

        Raises:
            ValueError: If user-supplied `latents` do not have exactly `shape`.
        """
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        hint: torch.FloatTensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only used when guidance is enabled (`guidance_scale > 1`).
            hint (`torch.FloatTensor`):
                The controlnet condition. A list of tensors is also accepted and concatenated along dim 0.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`

        Raises:
            ValueError: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`.
        """
        # Fail fast: reject an unsupported output type before spending time on the denoising loop.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
        if isinstance(hint, list):
            hint = torch.cat(hint, dim=0)

        batch_size = image_embeds.shape[0] * num_images_per_prompt

        # The conditioning must match the latent batch (`batch_size`) whether or not guidance is used, so the
        # per-prompt repetition and the dtype/device cast are applied on both paths.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hint = hint.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Unconditional half first, conditional half second; the order is relied on by `chunk(2)` below.
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            hint = torch.cat([hint, hint], dim=0)

        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)
        hint = hint.to(dtype=self.unet.dtype, device=device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps_tensor = self.scheduler.timesteps

        num_channels_latents = self.movq.config.latent_channels

        height, width = downscale_height_and_width(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # The scheduler configuration does not change during sampling, so evaluate this once.
        scheduler_uses_learned_variance = hasattr(
            self.scheduler.config, "variance_type"
        ) and self.scheduler.config.variance_type in ["learned", "learned_range"]

        added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}

        for t in self.progress_bar(timesteps_tensor):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The UNet output is split along channels into a noise part and a variance part; guidance is
                # applied to the noise part only and the conditional variance is re-attached.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not scheduler_uses_learned_variance:
                # Schedulers without a learned variance only take the noise channels.
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type in ["np", "pil"]:
            # Map from [-1, 1] to [0, 1] and move channels last for numpy / PIL consumers.
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
from typing import List, Optional, Union import numpy as np import PIL import torch from PIL import Image from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import numpy as np >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22ControlnetImg2ImgPipeline >>> from transformers import pipeline >>> from diffusers.utils import load_image >>> def make_hint(image, depth_estimator): ... image = depth_estimator(image)["depth"] ... image = np.array(image) ... image = image[:, :, None] ... image = np.concatenate([image, image, image], axis=2) ... detected_map = torch.from_numpy(image).float() / 255.0 ... hint = detected_map.permute(2, 0, 1) ... return hint >>> depth_estimator = pipeline("depth-estimation") >>> pipe_prior = KandinskyV22PriorEmb2EmbPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior = pipe_prior.to("cuda") >>> pipe = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... 
).resize((768, 768)) >>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda") >>> prompt = "A robot, 4k photo" >>> negative_prior_prompt = "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature" >>> generator = torch.Generator(device="cuda").manual_seed(43) >>> img_emb = pipe_prior(prompt=prompt, image=img, strength=0.85, generator=generator) >>> negative_emb = pipe_prior(prompt=negative_prior_prompt, image=img, strength=1, generator=generator) >>> images = pipe( ... image=img, ... strength=0.5, ... image_embeds=img_emb.image_embeds, ... negative_image_embeds=negative_emb.image_embeds, ... hint=hint, ... num_inference_steps=50, ... generator=generator, ... height=768, ... width=768, ... 
class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for image-to-image generation using Kandinsky, conditioned on an additional controlnet `hint`.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        scheduler ([`DDPMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ Decoder to generate the image from the latents.
    """

    def __init__(
        self,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        movq: VQModel,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            movq=movq,
        )
        # Derived from the number of entries in the MoVQ `block_out_channels` config.
        self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        """
        Return the tail of `self.scheduler.timesteps` selected by `strength`, plus the number of steps it spans.

        `device` is accepted but not used by this method.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2_img2img.KandinskyV22Img2ImgPipeline.prepare_latents
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """
        Produce the noised starting latents for `image` at `timestep`.

        A tensor with 4 channels (`image.shape[1] == 4`) is used directly as latents; anything else is encoded with
        `self.movq` and multiplied by `self.movq.config.scaling_factor`. Noise is then added via
        `self.scheduler.add_noise`.

        Raises:
            ValueError: If `image` has an unsupported type, or if a list of generators does not match the
                effective batch size (`batch_size * num_images_per_prompt`).
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # NOTE(review): the check above admits PIL images and lists, but `.to` below only exists on tensors.
        # `__call__` always passes a tensor here.
        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = self.movq.encode(image).latent_dist.sample(generator)

            init_latents = self.movq.config.scaling_factor * init_latents

        init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

        latents = init_latents

        return latents

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [
            self.unet,
            self.movq,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.unet, self.movq]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]],
        negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        hint: torch.FloatTensor,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 100,
        guidance_scale: float = 4.0,
        strength: float = 0.3,
        num_images_per_prompt: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for text prompt, that will be used to condition the image generation.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. PIL images are resized to `(width, height)` and converted with `prepare_image`; tensors
                are used as given. Every input is encoded with `movq` before noise is added.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            hint (`torch.FloatTensor`):
                The controlnet condition. A list of tensors is also accepted and concatenated along dim 0.
            negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                The clip image embeddings for negative text prompt, will be used to condition the image generation.
                Only used when guidance is enabled (`guidance_scale > 1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`

        Raises:
            ValueError: If `output_type` is not one of `"pt"`, `"np"` or `"pil"`, or if `image` contains anything
                other than PIL images and torch tensors.
        """
        # Fail fast: reject an unsupported output type before spending time on encoding and denoising.
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
        if isinstance(hint, list):
            hint = torch.cat(hint, dim=0)

        batch_size = image_embeds.shape[0]

        # The conditioning must match the latent batch (`batch_size * num_images_per_prompt`) whether or not
        # guidance is used, so the repetition and the dtype/device cast are applied on both paths.
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        hint = hint.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Unconditional half first, conditional half second; the order is relied on by `chunk(2)` below.
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
            hint = torch.cat([hint, hint], dim=0)

        image_embeds = image_embeds.to(dtype=self.unet.dtype, device=device)
        hint = hint.to(dtype=self.unet.dtype, device=device)

        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )

        # `prepare_image` relies on the PIL API, so tensors (which pass the check above) are forwarded untouched.
        # NOTE(review): tensors are assumed to be already batched and normalised like `prepare_image` output —
        # confirm against callers.
        image = torch.cat(
            [i if isinstance(i, torch.Tensor) else prepare_image(i, width, height) for i in image], dim=0
        )
        image = image.to(dtype=image_embeds.dtype, device=device)

        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        # Every sample is noised to the first timestep of the truncated schedule.
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        latents = self.prepare_latents(
            latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
        )

        # The scheduler configuration does not change during sampling, so evaluate this once.
        scheduler_uses_learned_variance = hasattr(
            self.scheduler.config, "variance_type"
        ) and self.scheduler.config.variance_type in ["learned", "learned_range"]

        added_cond_kwargs = {"image_embeds": image_embeds, "hint": hint}

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=None,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                # The UNet output is split along channels into a noise part and a variance part; guidance is
                # applied to the noise part only and the conditional variance is re-attached.
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not scheduler_uses_learned_variance:
                # Schedulers without a learned variance only take the noise channels.
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            )[0]

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        if output_type in ["np", "pil"]:
            # Map from [-1, 1] to [0, 1] and move channels last for numpy / PIL consumers.
            image = image * 0.5 + 0.5
            image = image.clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
def downscale_height_and_width(height, width, scale_factor=8):
    """
    Convert a pixel size into the latent size used by the pipeline.

    Each dimension is divided by `scale_factor**2`, rounded up to the next whole number, and the result is
    multiplied by `scale_factor`.

    Args:
        height: Requested image height.
        width: Requested image width.
        scale_factor: Scale factor; the pipeline passes its `movq_scale_factor`.

    Returns:
        A `(height, width)` tuple of the downscaled dimensions.
    """
    block = scale_factor**2

    def _ceil_blocks(extent):
        # Number of `block`-sized cells needed to cover `extent` (any remainder costs one extra cell).
        whole, leftover = divmod(extent, block)
        return whole + 1 if leftover != 0 else whole

    return _ceil_blocks(height) * scale_factor, _ceil_blocks(width) * scale_factor
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timesteps that corresponds to `strength`.

    Args:
        num_inference_steps (`int`): Number of steps the scheduler was configured with.
        strength (`float`): Fraction of the schedule to run; the resulting step count is capped at
            `num_inference_steps`.
        device: Accepted for interface compatibility; not read by this method.

    Returns:
        `tuple`: `(timesteps, steps)` where `timesteps` is `self.scheduler.timesteps` sliced from the
        computed start index and `steps` is `num_inference_steps` minus that start index.
    """
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    # index of the first timestep that is actually executed
    first_index = num_inference_steps - steps_to_run
    if first_index < 0:
        first_index = 0

    remaining = self.scheduler.timesteps[first_index:]
    return remaining, num_inference_steps - first_index
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's models to CPU using accelerate.

    `self.unet` and `self.movq` are each passed to accelerate's `cpu_offload` with `cuda:{gpu_id}` as the
    execution device; entries that are `None` are skipped.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device used as execution device.

    Raises:
        ImportError: If accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    target = torch.device(f"cuda:{gpu_id}")

    for submodel in (self.unet, self.movq):
        if submodel is None:
            continue
        cpu_offload(submodel, target)
Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, strength: float = 0.3, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. 
A value of 1, therefore, essentially ignores `image`. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) batch_size = image_embeds.shape[0] if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) if not isinstance(image, list): image = [image] if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" ) image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) image = image.to(dtype=image_embeds.dtype, device=device) latents = self.movq.encode(image)["latents"] latents = latents.repeat_interleave(num_images_per_prompt, dim=0) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) height, width = downscale_height_and_width(height, width, self.movq_scale_factor) latents = self.prepare_latents( latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator ) for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents added_cond_kwargs = {"image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=None, 
added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, )[0] # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy from typing import List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from PIL import Image from ...models import UNet2DConditionModel, VQModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import DDPMScheduler from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22InpaintPipeline, KandinskyV22PriorPipeline >>> from diffusers.utils import load_image >>> import torch >>> import numpy as np >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "a hat" >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) >>> pipe = KandinskyV22InpaintPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> init_image = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> mask = np.ones((768, 768), dtype=np.float32) >>> mask[:250, 250:-250] = 0 >>> out = pipe( ... image=init_image, ... mask_image=mask, ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... 
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_inpaint.prepare_mask
def prepare_mask(masks):
    """
    Grow the "hole" region of every mask by one pixel in six fixed directions.

    A pixel counts as a hole when channel 0 of the *incoming* mask is not exactly `1` there. For every hole
    at `(i, j)`, the pixels at `(i-1, j)`, `(i, j-1)`, `(i-1, j-1)`, `(i+1, j)`, `(i, j+1)` and
    `(i+1, j+1)` that lie inside the mask are set to `0` in all channels. The hole pixel itself is left
    untouched.

    The work is done with shifted boolean slices rather than a Python loop over every pixel, which gives the
    same result without running interpreter code once per pixel.

    Args:
        masks: Iterable of `torch.Tensor` masks of shape `(channels, height, width)`, typically a single
            `(batch, channels, height, width)` tensor. Each mask is modified in place.

    Returns:
        `torch.Tensor`: The processed masks stacked along a new leading dimension.
    """
    prepared_masks = []
    for mask in masks:
        # snapshot of the holes, taken before any pixel of this mask is overwritten
        hole = mask[0] != 1
        kill = torch.zeros_like(hole)

        # neighbours above / left / above-left of a hole
        kill[:-1, :] |= hole[1:, :]
        kill[:, :-1] |= hole[:, 1:]
        kill[:-1, :-1] |= hole[1:, 1:]
        # neighbours below / right / below-right of a hole
        kill[1:, :] |= hole[:-1, :]
        kill[:, 1:] |= hole[:, :-1]
        kill[1:, 1:] |= hole[:-1, :-1]

        # zero the selected pixels in every channel, in place
        mask[:, kill] = 0
        prepared_masks.append(mask)
    return torch.stack(prepared_masks, dim=0)
mask (_type_): The mask to apply to the image, i.e. regions to inpaint. It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. Raises: ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (ot the other way around). Returns: tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ if image is None: raise ValueError("`image` input cannot be undefined.") if mask is None: raise ValueError("`mask_image` input cannot be undefined.") if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") # Batch single image if image.ndim == 3: assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" image = image.unsqueeze(0) # Batch and add channel dim for single mask if mask.ndim == 2: mask = mask.unsqueeze(0).unsqueeze(0) # Batch single mask or add channel dim if mask.ndim == 3: # Single batched mask, no channel dim or single mask not batched but channel dim if mask.shape[0] == 1: mask = mask.unsqueeze(0) # Batched masks no channel dim else: mask = mask.unsqueeze(1) assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" # Check image is in [-1, 1] if image.min() < -1 or 
image.max() > 1: raise ValueError("Image should be in [-1, 1] range") # Check mask is in [0, 1] if mask.min() < 0 or mask.max() > 1: raise ValueError("Mask should be in [0, 1] range") # Binarize mask mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 # Image as float32 image = image.to(dtype=torch.float32) elif isinstance(mask, torch.Tensor): raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): # resize all images w.r.t passed height an width image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 # preprocess mask if isinstance(mask, (PIL.Image.Image, np.ndarray)): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): mask = np.concatenate([m[None, None, :] for m in mask], axis=0) mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) return mask, image class KandinskyV22InpaintPipeline(DiffusionPipeline): """ Pipeline for text-guided image inpainting using Kandinsky2.1 This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
    """
    Return the starting latents, scaled by the scheduler's `init_noise_sigma`.

    Args:
        shape: Expected shape of the latents.
        dtype: Dtype used when new latents are drawn.
        device: Device the latents are created on or moved to.
        generator: Random generator forwarded to `randn_tensor` when new latents are drawn.
        latents: Optional pre-made latents; when `None`, new ones are drawn with `randn_tensor`.
        scheduler: Object providing `init_noise_sigma`.

    Returns:
        The latents multiplied by `scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `latents` is given and its shape differs from `shape`.
    """
    if latents is not None:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        initial = latents.to(device)
    else:
        initial = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return initial * scheduler.init_noise_sigma
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.unet, self.movq, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) # Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.KandinskyV22Pipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.unet, self.movq]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], image: Union[torch.FloatTensor, PIL.Image.Image], mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Args: Function invoked when calling the pipeline for generation. image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for text prompt, that will be used to condition the image generation. image (`PIL.Image.Image`): `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will be masked out with `mask_image` and repainted according to `prompt`. mask_image (`np.array`): Tensor representing an image batch, to mask `image`. Black pixels in the mask will be repainted, while white pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. 
height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple` """ device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 if isinstance(image_embeds, list): image_embeds = torch.cat(image_embeds, dim=0) batch_size = image_embeds.shape[0] * num_images_per_prompt if isinstance(negative_image_embeds, list): negative_image_embeds = torch.cat(negative_image_embeds, dim=0) if do_classifier_free_guidance: image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps_tensor = self.scheduler.timesteps # preprocess image and mask mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) image = image.to(dtype=image_embeds.dtype, device=device) image = self.movq.encode(image)["latents"] mask_image = mask_image.to(dtype=image_embeds.dtype, device=device) image_shape = tuple(image.shape[-2:]) mask_image = F.interpolate( mask_image, image_shape, mode="nearest", ) mask_image = prepare_mask(mask_image) masked_image = image * mask_image mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: mask_image = mask_image.repeat(2, 1, 1, 1) masked_image = masked_image.repeat(2, 1, 1, 1) num_channels_latents = self.movq.config.latent_channels height, width = downscale_height_and_width(height, width, self.movq_scale_factor) # create initial latent latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), image_embeds.dtype, device, generator, latents, self.scheduler, ) noise = torch.clone(latents) for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing 
classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) added_cond_kwargs = {"image_embeds": image_embeds} noise_pred = self.unet( sample=latent_model_input, timestep=t, encoder_hidden_states=None, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) _, variance_pred_text = variance_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) if not ( hasattr(self.scheduler.config, "variance_type") and self.scheduler.config.variance_type in ["learned", "learned_range"] ): noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, generator=generator, )[0] init_latents_proper = image[:1] init_mask = mask_image[:1] if i < len(timesteps_tensor) - 1: noise_timestep = timesteps_tensor[i + 1] init_latents_proper = self.scheduler.add_noise( init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = init_mask * init_latents_proper + (1 - init_mask) * latents # post-processing latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents image = self.movq.decode(latents, force_not_quantize=True)["sample"] if output_type not in ["pt", "np", "pil"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") if output_type in ["np", "pil"]: image = image * 0.5 + 0.5 image = image.clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) 
================================================ FILE: 6DoF/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py ================================================ from typing import List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) from ..kandinsky import KandinskyPriorPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorPipeline >>> import torch >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior") >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> image_emb, negative_image_emb = pipe_prior(prompt).to_tuple() >>> pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder") >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=50, ... ).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... 
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> out = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyV22Pipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=out.image_embeds, ... negative_image_embeds=out.negative_image_embeds, ... height=768, ... width=768, ... num_inference_steps=50, ... ).images[0] >>> image.save("starry_cat.png") ``` """ class KandinskyV22PriorPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. image_processor ([`CLIPImageProcessor`]): A image_processor to be used to preprocess image from clip. 
""" def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( prior=prior, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) @torch.no_grad() @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) def interpolate( self, images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], weights: List[float], num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, negative_prior_prompt: Optional[str] = None, negative_prompt: Union[str] = "", guidance_scale: float = 4.0, device=None, ): """ Function invoked when using the prior pipeline for interpolation. Args: images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): list of prompts and images to guide the image generation. weights: (`List[float]`): list of weights for each condition in `images_and_prompts` num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. negative_prior_prompt (`str`, *optional*): The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ device = device or self.device if len(images_and_prompts) != len(weights): raise ValueError( f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" ) image_embeddings = [] for cond, weight in zip(images_and_prompts, weights): if isinstance(cond, str): image_emb = self( cond, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ).image_embeds.unsqueeze(0) elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): if isinstance(cond, PIL.Image.Image): cond = ( self.image_processor(cond, return_tensors="pt") .pixel_values[0] .unsqueeze(0) .to(dtype=self.image_encoder.dtype, device=device) ) image_emb = self.image_encoder(cond)["image_embeds"].repeat(num_images_per_prompt, 1).unsqueeze(0) else: raise ValueError( f"`images_and_prompts` can only contains 
elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" ) image_embeddings.append(image_emb * weight) image_emb = torch.cat(image_embeddings).sum(dim=0) out_zero = self( negative_prompt, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ) zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed def get_zero_embed(self, batch_size=1, device=None): device = device or self.device zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( device=device, dtype=self.image_encoder.dtype ) zero_image_emb = self.image_encoder(zero_img)["image_embeds"] zero_image_emb = zero_image_emb.repeat(batch_size, 1) return zero_image_emb # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.image_encoder, self.text_encoder, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): return self.device for module in self.text_encoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not 
torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]], negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, guidance_scale: float = 4.0, output_type: Optional[str] = "pt", # pt only return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. output_type (`str`, *optional*, defaults to `"pt"`): The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` (`torch.Tensor`). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ if isinstance(prompt, str): prompt = [prompt] elif not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] elif not isinstance(negative_prompt, list) and negative_prompt is not None: raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") # if the negative prompt is defined we double the batch size to # directly retrieve the negative prompt embedding if negative_prompt is not None: prompt = prompt + negative_prompt negative_prompt = 2 * negative_prompt device = self._execution_device batch_size = len(prompt) batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # prior self.scheduler.set_timesteps(num_inference_steps, device=device) prior_timesteps_tensor = 
self.scheduler.timesteps embedding_dim = self.prior.config.embedding_dim latents = self.prepare_latents( (batch_size, embedding_dim), prompt_embeds.dtype, device, generator, latents, self.scheduler, ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = prior_timesteps_tensor[i + 1] latents = self.scheduler.step( predicted_image_embedding, timestep=t, sample=latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample latents = self.prior.post_process_latents(latents) image_embeddings = latents # if negative prompt has been defined, we retrieve split the image embedding into two if negative_prompt is None: zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) else: image_embeddings, zero_embeds = image_embeddings.chunk(2) if output_type not in ["pt", "np"]: raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") if output_type == "np": image_embeddings = image_embeddings.cpu().numpy() zero_embeds = zero_embeds.cpu().numpy() if not return_dict: return (image_embeddings, zero_embeds) return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) ================================================ FILE: 
6DoF/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py ================================================ from typing import List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import UnCLIPScheduler from ...utils import ( is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) from ..kandinsky import KandinskyPriorPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22Pipeline, KandinskyV22PriorEmb2EmbPipeline >>> import torch >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... ) >>> pipe_prior.to("cuda") >>> prompt = "red cat, 4k photo" >>> img = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> image_emb, nagative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple() >>> pipe = KandinskyPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder, torch_dtype=torch.float16" ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=negative_image_emb, ... height=768, ... width=768, ... num_inference_steps=100, ... ).images >>> image[0].save("cat.png") ``` """ EXAMPLE_INTERPOLATE_DOC_STRING = """ Examples: ```py >>> from diffusers import KandinskyV22PriorEmb2EmbPipeline, KandinskyV22Pipeline >>> from diffusers.utils import load_image >>> import PIL >>> import torch >>> from torchvision import transforms >>> pipe_prior = KandinskyV22PriorPipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16 ... 
) >>> pipe_prior.to("cuda") >>> img1 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/cat.png" ... ) >>> img2 = load_image( ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" ... "/kandinsky/starry_night.jpeg" ... ) >>> images_texts = ["a cat", img1, img2] >>> weights = [0.3, 0.3, 0.4] >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) >>> pipe = KandinskyV22Pipeline.from_pretrained( ... "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16 ... ) >>> pipe.to("cuda") >>> image = pipe( ... image_embeds=image_emb, ... negative_image_embeds=zero_image_emb, ... height=768, ... width=768, ... num_inference_steps=150, ... ).images[0] >>> image.save("starry_cat.png") ``` """ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline): """ Pipeline for generating image prior for Kandinsky This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen image-encoder. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`UnCLIPScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. 
""" def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModelWithProjection, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, scheduler: UnCLIPScheduler, image_processor: CLIPImageProcessor, ): super().__init__() self.register_modules( prior=prior, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start:] return timesteps, num_inference_steps - t_start @torch.no_grad() @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) def interpolate( self, images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], weights: List[float], num_images_per_prompt: int = 1, num_inference_steps: int = 25, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, negative_prior_prompt: Optional[str] = None, negative_prompt: Union[str] = "", guidance_scale: float = 4.0, device=None, ): """ Function invoked when using the prior pipeline for interpolation. Args: images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): list of prompts and images to guide the image generation. weights: (`List[float]`): list of weights for each condition in `images_and_prompts` num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. negative_prior_prompt (`str`, *optional*): The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). negative_prompt (`str` or `List[str]`, *optional*): The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
Examples: Returns: [`KandinskyPriorPipelineOutput`] or `tuple` """ device = device or self.device if len(images_and_prompts) != len(weights): raise ValueError( f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" ) image_embeddings = [] for cond, weight in zip(images_and_prompts, weights): if isinstance(cond, str): image_emb = self( cond, num_inference_steps=num_inference_steps, num_images_per_prompt=num_images_per_prompt, generator=generator, latents=latents, negative_prompt=negative_prior_prompt, guidance_scale=guidance_scale, ).image_embeds.unsqueeze(0) elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): image_emb = self._encode_image( cond, device=device, num_images_per_prompt=num_images_per_prompt ).unsqueeze(0) else: raise ValueError( f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" ) image_embeddings.append(image_emb * weight) image_emb = torch.cat(image_embeddings).sum(dim=0) return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=torch.randn_like(image_emb)) def _encode_image( self, image: Union[torch.Tensor, List[PIL.Image.Image]], device, num_images_per_prompt, ): if not isinstance(image, torch.Tensor): image = self.image_processor(image, return_tensors="pt").pixel_values.to( dtype=self.image_encoder.dtype, device=device ) image_emb = self.image_encoder(image)["image_embeds"] # B, D image_emb = image_emb.repeat_interleave(num_images_per_prompt, dim=0) image_emb.to(device=device) return image_emb def prepare_latents(self, emb, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): emb = emb.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt init_latents = emb if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: additional_image_per_prompt = batch_size // init_latents.shape[0] 
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.get_zero_embed def get_zero_embed(self, batch_size=1, device=None): device = device or self.device zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( device=device, dtype=self.image_encoder.dtype ) zero_image_emb = self.image_encoder(zero_img)["image_embeds"] zero_image_emb = zero_image_emb.repeat(batch_size, 1) return zero_image_emb # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [ self.image_encoder, self.text_encoder, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): return self.device for module in self.text_encoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_prior.KandinskyPriorPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, ): batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated 
because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: Union[torch.Tensor, List[torch.Tensor], PIL.Image.Image, List[PIL.Image.Image]],
    strength: float = 0.3,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    guidance_scale: float = 4.0,
    output_type: Optional[str] = "pt",  # pt only
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image (`torch.Tensor`, `List[torch.Tensor]`, `PIL.Image.Image`, or `List[PIL.Image.Image]`):
            Starting point of the process. A list of tensors is concatenated along dim 0. A 2-D tensor is used
            directly as image embeddings; a 4-D tensor or any non-tensor input is passed to
            `self._encode_image`. Tensors of any other rank are rejected.
        strength (`float`, *optional*, defaults to 0.3):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Accepted for signature compatibility only. The value is currently ignored: the latents are always
            initialised from the embeddings derived from `image`.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        output_type (`str`, *optional*, defaults to `"pt"`):
            The output format of the generated embeddings. Choose between: `"np"` (`np.array`) or `"pt"`
            (`torch.Tensor`).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`KandinskyPriorPipelineOutput`] instead of a plain tuple.

    Raises:
        ValueError: If `prompt` or `negative_prompt` is neither a `str` nor a `list`, if `output_type` is not
            `"pt"` or `"np"`, or if `image` is a tensor whose rank is neither 2 nor 4.

    Examples:

    Returns:
        [`KandinskyPriorPipelineOutput`] or `tuple`
    """

    if isinstance(prompt, str):
        prompt = [prompt]
    elif not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    elif not isinstance(negative_prompt, list) and negative_prompt is not None:
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # Reject an unsupported output type before encoding and denoising, so that an
    # invalid argument does not cost a full sampling run.
    if output_type not in ["pt", "np"]:
        raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")

    # if the negative prompt is defined we double the batch size to
    # directly retrieve the negative prompt embedding
    if negative_prompt is not None:
        prompt = prompt + negative_prompt
        negative_prompt = 2 * negative_prompt

    device = self._execution_device

    batch_size = len(prompt)
    batch_size = batch_size * num_images_per_prompt

    do_classifier_free_guidance = guidance_scale > 1.0
    prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
        prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    if not isinstance(image, List):
        image = [image]

    if isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    if isinstance(image, torch.Tensor) and image.ndim == 2:
        # allow user to pass image_embeds directly
        image_embeds = image.repeat_interleave(num_images_per_prompt, dim=0)
    elif isinstance(image, torch.Tensor) and image.ndim != 4:
        raise ValueError(
            f" if pass `image` as pytorch tensor, or a list of pytorch tensor, please make sure each tensor has shape [batch_size, channels, height, width], currently {image[0].unsqueeze(0).shape}"
        )
    else:
        image_embeds = self._encode_image(image, device, num_images_per_prompt)

    # prior
    self.scheduler.set_timesteps(num_inference_steps, device=device)

    # NOTE: this overwrites the `latents` argument; the image embeddings are the starting point.
    latents = image_embeds
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size)
    latents = self.prepare_latents(
        latents,
        latent_timestep,
        batch_size // num_images_per_prompt,
        num_images_per_prompt,
        prompt_embeds.dtype,
        device,
        generator,
    )

    for i, t in enumerate(self.progress_bar(timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

        predicted_image_embedding = self.prior(
            latent_model_input,
            timestep=t,
            proj_embedding=prompt_embeds,
            encoder_hidden_states=text_encoder_hidden_states,
            attention_mask=text_mask,
        ).predicted_image_embedding

        if do_classifier_free_guidance:
            predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
            predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
                predicted_image_embedding_text - predicted_image_embedding_uncond
            )

        # the last step has no successor timestep
        if i + 1 == timesteps.shape[0]:
            prev_timestep = None
        else:
            prev_timestep = timesteps[i + 1]

        latents = self.scheduler.step(
            predicted_image_embedding,
            timestep=t,
            sample=latents,
            generator=generator,
            prev_timestep=prev_timestep,
        ).prev_sample

    latents = self.prior.post_process_latents(latents)

    image_embeddings = latents

    # if a negative prompt was given, the batch was doubled above: the first half holds the
    # prompt embeddings and the second half the negative-prompt embeddings
    if negative_prompt is None:
        zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
    else:
        image_embeddings, zero_embeds = image_embeddings.chunk(2)

    if output_type == "np":
        image_embeddings = image_embeddings.cpu().numpy()
        zero_embeds = zero_embeds.cpu().numpy()

    if not return_dict:
        return (image_embeddings, zero_embeds)

    return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
class LDMTextToImagePipeline(DiffusionPipeline):
    r"""
    Text-to-image generation pipeline based on latent diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
        bert ([`LDMBertModel`]):
            Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
        tokenizer (`transformers.BertTokenizer`):
            Tokenizer of class
            [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: Union[VQModel, AutoencoderKL],
        bert: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        unet: Union[UNet2DModel, UNet2DConditionModel],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
        # one factor of two for every block of the VQ-VAE after the first
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 1.0,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Generate images from one or more text prompts.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                Any value other than exactly `1.0` turns guidance on; `1.0` runs a single conditional pass.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` only when that method's signature has an `eta` parameter.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic. A list must have one entry per prompt.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents to start from. Must have shape
                `(batch_size, unet.config.in_channels, height // 8, width // 8)`. When omitted, latents are
                sampled with `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` converts the result to PIL images; any other value returns the numpy array.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            **kwargs:
                Accepted but not used by this method.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        # Fall back to the UNet's native resolution for any dimension the caller left out.
        if not height:
            height = self.unet.config.sample_size * self.vae_scale_factor
        if not width:
            width = self.unet.config.sample_size * self.vae_scale_factor

        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        use_guidance = guidance_scale != 1.0
        tokenizer_kwargs = {"padding": "max_length", "max_length": 77, "truncation": True, "return_tensors": "pt"}

        # Unconditional (empty prompt) embeddings are only needed for classifier free guidance.
        if use_guidance:
            uncond_tokens = self.tokenizer([""] * batch_size, **tokenizer_kwargs)
            negative_prompt_embeds = self.bert(uncond_tokens.input_ids.to(self.device))[0]

        # Embeddings of the actual prompt(s).
        prompt_tokens = self.tokenizer(prompt, **tokenizer_kwargs)
        prompt_embeds = self.bert(prompt_tokens.input_ids.to(self.device))[0]

        # Start from random noise unless the caller supplied latents.
        latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is not None:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        else:
            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=prompt_embeds.dtype)
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        # Not every scheduler's `step` takes `eta`, so only pass it where it is accepted.
        step_parameters = inspect.signature(self.scheduler.step).parameters
        extra_kwargs = {"eta": eta} if "eta" in step_parameters else {}

        # With guidance on, the unconditional and text embeddings are batched together so a
        # single UNet call covers both passes. The context does not change between steps.
        if use_guidance:
            context = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            context = prompt_embeds

        for t in self.progress_bar(self.scheduler.timesteps):
            latents_input = torch.cat([latents] * 2) if use_guidance else latents

            # predict the noise residual
            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample

            if use_guidance:
                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

        # Undo the latent scaling, decode, and map from [-1, 1] to [0, 1] channels-last numpy.
        latents = 1 / self.vqvae.config.scaling_factor * latents
        decoded = self.vqvae.decode(latents).sample
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        image = decoded.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if return_dict:
            return ImagePipelineOutput(images=image)
        return (image,)
class LDMBertConfig(PretrainedConfig):
    """
    Configuration for the LDMBert text encoder.

    Every constructor argument is stored on the instance under the same name. `encoder_layers` is additionally
    stored as `num_hidden_layers`. `pad_token_id` and any remaining keyword arguments are forwarded to
    `PretrainedConfig.__init__`.
    """

    model_type = "ldmbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=77,
        encoder_layers=32,
        encoder_ffn_dim=5120,
        encoder_attention_heads=8,
        head_dim=64,
        encoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1280,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        pad_token_id=0,
        **kwargs,
    ):
        # embedding tables
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # encoder geometry
        self.d_model = d_model
        self.head_dim = head_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.num_hidden_layers = encoder_layers
        self.activation_function = activation_function

        # regularisation
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.classifier_dropout = classifier_dropout
        self.encoder_layerdrop = encoder_layerdrop

        # initialisation and caching
        self.init_std = init_std
        self.use_cache = use_cache

        super().__init__(pad_token_id=pad_token_id, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
class LDMBertAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    The width of each head (`head_dim`) is given explicitly, so the query/key/value projections map
    `embed_dim -> num_heads * head_dim` (`inner_dim`) and `out_proj` maps `inner_dim` back to `embed_dim`.

    Args:
        embed_dim: Size of the last dimension of the input and output hidden states.
        num_heads: Number of attention heads.
        head_dim: Width of each attention head.
        dropout: Dropout probability applied to the attention probabilities (active in training mode only).
        is_decoder: When `True`, `forward` returns the key/value states so they can be reused on later calls.
        bias: Whether the query/key/value projections have a bias term. `out_proj` always has one.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim
        self.inner_dim = head_dim * num_heads

        # queries are pre-scaled by 1/sqrt(head_dim) before the dot product
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.out_proj = nn.Linear(self.inner_dim, embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, inner_dim)` into `(bsz, num_heads, seq_len, head_dim)`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query source of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: If given, keys and values are projected from this tensor (cross-attention)
                instead of from `hidden_states`.
            past_key_value: Cached `(key_states, value_states)`. For cross-attention they are reused as is;
                for self-attention the newly projected states are appended to them along the sequence axis.
            attention_mask: Additive mask of shape `(bsz, 1, tgt_len, src_len)`, added to the attention
                scores before the softmax.
            layer_head_mask: Multiplicative per-head mask of shape `(num_heads,)`, applied after the softmax.
            output_attentions: Whether to also return the attention weights.

        Returns:
            A tuple `(attn_output, attn_weights, past_key_value)`. `attn_output` has shape
            `(bsz, tgt_len, embed_dim)`. `attn_weights` has shape `(bsz, num_heads, tgt_len, src_len)` or is
            `None` when `output_attentions` is `False`. `past_key_value` is the argument passed in unless
            `is_decoder` is set, in which case it is the `(key_states, value_states)` used by this call.

        Raises:
            ValueError: If the attention scores, `attention_mask`, `layer_head_mask` or the attention output do
                not have the expected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # fold the head axis into the batch axis so a single bmm handles every head
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            # report the same (flattened) shape that the check above compares against
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Merge the heads back using `inner_dim` (num_heads * head_dim), which is what `out_proj` expects
        # as input; it may differ from `embed_dim` because `head_dim` is configured independently.
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
class LDMBertPreTrainedModel(PreTrainedModel):
    """Base class for the LDMBert models: weight initialisation, checkpointing toggle and dummy inputs."""

    config_class = LDMBertConfig
    base_model_prefix = "model"
    _supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

    def _init_weights(self, module):
        """Draw `nn.Linear` / `nn.Embedding` weights from N(0, init_std); zero biases and the padding row."""
        std = self.config.init_std
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            # keep the embedding of the padding token at zero
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `module` if it is an `LDMBertEncoder`."""
        if isinstance(module, LDMBertEncoder):
            module.gradient_checkpointing = value

    @property
    def dummy_inputs(self):
        """A fixed two-sequence batch of token ids together with its non-padding mask."""
        pad_token = self.config.pad_token_id
        token_ids = torch.tensor(
            [
                [0, 6, 10, 4, 2],
                [0, 8, 12, 2, pad_token],
            ],
            device=self.device,
        )
        return {
            "attention_mask": token_ids.ne(pad_token),
            "input_ids": token_ids,
        }
Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) seq_len = input_shape[1] if position_ids is None: position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) embed_pos = self.embed_positions(position_ids) hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: if head_mask.size()[0] != (len(self.layers)): raise ValueError( f"The head_mask should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." 
class LDMBertModel(LDMBertPreTrainedModel):
    """Wrapper module that owns an `LDMBertEncoder` and delegates `forward` to it.

    A `to_logits` linear layer (`hidden_size -> vocab_size`) is registered as a submodule, but
    `forward` does not apply it: the encoder output is returned as-is.
    """

    _no_split_modules = []

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)
        self.model = LDMBertEncoder(config)
        self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the wrapped encoder and return whatever it returns.

        `input_ids` is passed positionally; every other argument is forwarded by keyword,
        unchanged.
        """
        encoder_kwargs = {
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.model(input_ids, **encoder_kwargs)
def preprocess(image):
    """Convert a PIL image into a batched, channels-first float tensor scaled to `[-1, 1]`.

    Width and height are first rounded down to the nearest multiple of 32 and the image is
    resized accordingly with Lanczos resampling.
    """
    width, height = image.size
    # resize to integer multiple of 32
    width -= width % 32
    height -= height % 32
    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
    pixels = np.array(resized).astype(np.float32) / 255.0
    # add a leading batch axis, then move channels in front of the spatial axes
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2.0 * torch.from_numpy(pixels) - 1.0


class LDMSuperResolutionPipeline(DiffusionPipeline):
    r"""
    A pipeline for image super-resolution using latent diffusion.

    This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
            [`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vqvae: VQModel,
        unet: UNet2DModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        batch_size: Optional[int] = 1,
        num_inference_steps: Optional[int] = 100,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            image (`torch.Tensor` or `PIL.Image.Image`):
                Low-resolution input. A PIL image is converted with `preprocess`; a tensor is used as given.
            batch_size (`int`, *optional*, defaults to 1):
                Ignored in practice: the value is always re-derived from `image` (1 for a PIL image, the leading
                dimension for a tensor).
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps handed to the scheduler.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. It is only
                forwarded when the scheduler's `step` signature has an `eta` parameter.
            generator (`torch.Generator`, *optional*):
                One or a list of torch generator(s) used to draw the initial noise.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` converts the result with `numpy_to_pil`; any other value leaves it as a numpy array.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a one-element `tuple` holding the
            generated images.

        Raises:
            ValueError: if `image` is neither a `PIL.Image.Image` nor a `torch.Tensor`.
        """
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
            image = preprocess(image)
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")

        height, width = image.shape[-2:]

        # in_channels should be 6: 3 for latents, 3 for low resolution image
        noise_channels = self.unet.config.in_channels // 2
        latents_dtype = next(self.unet.parameters()).dtype
        latents = randn_tensor(
            (batch_size, noise_channels, height, width),
            generator=generator,
            device=self.device,
            dtype=latents_dtype,
        )

        image = image.to(device=self.device, dtype=latents_dtype)

        # set timesteps and move to the correct device
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # Not every scheduler's `step` takes `eta`; only pass it where the signature accepts it.
        step_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            step_kwargs["eta"] = eta

        for t in self.progress_bar(timesteps):
            # concat latents and low resolution image in the channel dimension.
            unet_input = torch.cat([latents, image], dim=1)
            unet_input = self.scheduler.scale_model_input(unet_input, t)
            # predict the noise residual
            noise_pred = self.unet(unet_input, t).sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        # decode the image latents with the VQVAE
        decoded = self.vqvae.decode(latents).sample
        decoded = torch.clamp(decoded, -1.0, 1.0)
        # map [-1, 1] -> [0, 1]
        decoded = decoded / 2 + 0.5
        image = decoded.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if return_dict:
            return ImagePipelineOutput(images=image)
        return (image,)
class LDMPipeline(DiffusionPipeline):
    r"""
    Unconditional latent-diffusion image generation pipeline.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
    """

    def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
        super().__init__()
        self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            batch_size (`int`, *optional*, defaults to 1):
                Number of images to generate.
            generator (`torch.Generator`, *optional*):
                One or a list of torch generator(s) used to draw the initial noise.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` only when its signature has an `eta` parameter.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps handed to the scheduler.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"pil"` converts the result with `numpy_to_pil`; any other value leaves it as a numpy array.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            kwargs:
                Accepted but not used by this method.

        Returns:
            [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a one-element `tuple` holding the
            generated images.
        """
        unet_cfg = self.unet.config
        noise_shape = (batch_size, unet_cfg.in_channels, unet_cfg.sample_size, unet_cfg.sample_size)
        # The noise is drawn without a device argument and moved to the pipeline device afterwards.
        latents = randn_tensor(noise_shape, generator=generator).to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        step_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            step_kwargs["eta"] = eta

        for t in self.progress_bar(self.scheduler.timesteps):
            scaled_latents = self.scheduler.scale_model_input(latents, t)
            # predict the noise residual
            noise_pred = self.unet(scaled_latents, t).sample
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **step_kwargs).prev_sample

        # decode the image latents with the VAE, then map [-1, 1] -> [0, 1]
        decoded = self.vqvae.decode(latents).sample
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        image = decoded.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if return_dict:
            return ImagePipelineOutput(images=image)
        return (image,)
class OnnxRuntimeModel:
    """Small wrapper around an ONNX Runtime inference session with save/load helpers."""

    def __init__(self, model=None, **kwargs):
        """Store the session plus the bookkeeping needed to re-save it.

        Args:
            model: The object whose `run(None, inputs)` method is invoked by `__call__`.
            **kwargs: `model_save_dir` (defaults to `None`) and `latest_model_name` (defaults to
                `ONNX_WEIGHTS_NAME`) are read; other keys are ignored.
        """
        logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
        self.model = model
        self.model_save_dir = kwargs.get("model_save_dir", None)
        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

    def __call__(self, **kwargs):
        """Convert every keyword argument to a numpy array and run the session on them."""
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Path handed to `ort.InferenceSession` (callers in this class pass the model file path).
            provider(`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
            sess_options (*optional*):
                Forwarded unchanged to `ort.InferenceSession`.
        """
        if provider is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)

    def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
        """
        Copy the model file (and external weights, if present) from `self.model_save_dir` into `save_directory`.

        The source file is always `self.latest_model_name`.

        Arguments:
            save_directory (`str` or `Path`):
                Directory where to save the model file.
            file_name(`str`, *optional*):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
                model with a different name.
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME

        src_path = self.model_save_dir.joinpath(self.latest_model_name)
        dst_path = Path(save_directory).joinpath(model_file_name)
        try:
            shutil.copyfile(src_path, dst_path)
        except shutil.SameFileError:
            # Saving back into the directory the model was loaded from: nothing to copy.
            pass

        # copy external weights (for models >2GB)
        src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
        if src_path.exists():
            dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
            try:
                shutil.copyfile(src_path, dst_path)
            except shutil.SameFileError:
                pass

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        **kwargs,
    ):
        """
        Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
        method.

        If `save_directory` is an existing file, an error is logged and nothing is saved.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        # saving model weights/files
        self._save_pretrained(save_directory, **kwargs)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["ort.SessionOptions"] = None,
        **kwargs,
    ):
        """
        Load a model from a directory or the HF Hub.

        Arguments:
            model_id (`str` or `Path`):
                Local directory to load from; if it is not an existing directory it is used as a Hub repo id.
            use_auth_token (`str` or `bool`):
                Is needed to load models from a private or gated repository
            revision (`str`):
                Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
            cache_dir (`Union[str, Path]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            file_name(`str`):
                Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
                different model files from the same repository or directory.
            provider(`str`):
                The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
            kwargs (`Dict`, *optional*):
                kwargs will be passed to the model during initialization
        """
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(model_id):
            model = OnnxRuntimeModel.load_model(
                os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
            )
            kwargs["model_save_dir"] = Path(model_id)
        # load model from hub
        else:
            # download model
            model_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=model_file_name,
                use_auth_token=use_auth_token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
            )
            # Remember where the cached file lives so `_save_pretrained` can copy it later.
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
        return cls(model=model, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        force_download: bool = True,
        use_auth_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        **model_kwargs,
    ):
        """Load a model, accepting an optional `"<model_id>@<revision>"` suffix on `model_id`.

        When the string form of `model_id` contains exactly one `"@"`, the part after it is used as
        the revision and the part before it as the model id (as a `str`). Otherwise `model_id` is
        forwarded untouched with `revision=None`.

        Arguments:
            model_id (`str` or `Path`): Local directory or Hub repo id, optionally suffixed with `@revision`.
            force_download (`bool`, defaults to `True`): Forwarded to `_from_pretrained`.
                NOTE(review): this default differs from `_from_pretrained`'s `False`; confirm it is intended.
            use_auth_token (`str`, *optional*): Forwarded to `_from_pretrained`.
            cache_dir (`str`, *optional*): Forwarded to `_from_pretrained`.
            **model_kwargs: Forwarded to `_from_pretrained`.
        """
        revision = None
        # Split the string form: a `Path` has no `.split`, so splitting `model_id` itself would
        # raise AttributeError for Path inputs carrying an "@revision" suffix.
        id_parts = str(model_id).split("@")
        if len(id_parts) == 2:
            model_id, revision = id_parts

        return cls._from_pretrained(
            model_id=model_id,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            use_auth_token=use_auth_token,
            **model_kwargs,
        )
class PaintByExampleImageEncoder(CLIPPreTrainedModel):
    """CLIP vision encoder followed by a mapper, a layer norm and a linear projection.

    Also owns a learnable `uncond_vector` of shape `(1, 1, proj_size)` that `forward` can return
    alongside the projected embedding.
    """

    def __init__(self, config, proj_size=768):
        super().__init__(config)
        self.proj_size = proj_size
        hidden_size = config.hidden_size

        self.model = CLIPVisionModel(config)
        self.mapper = PaintByExampleMapper(config)
        self.final_layer_norm = nn.LayerNorm(hidden_size)
        self.proj_out = nn.Linear(hidden_size, self.proj_size)

        # uncondition for scaling
        self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))

    def forward(self, pixel_values, return_uncond_vector=False):
        """Encode `pixel_values` and project the pooled CLIP output.

        Returns the projected states, or the tuple `(projected states, self.uncond_vector)` when
        `return_uncond_vector` is truthy.
        """
        pooled = self.model(pixel_values=pixel_values).pooler_output
        # insert a singleton axis at position 1 before the mapper blocks
        mapped = self.mapper(pooled[:, None])
        projected = self.proj_out(self.final_layer_norm(mapped))
        if not return_uncond_vector:
            return projected
        return projected, self.uncond_vector
def prepare_mask_and_masked_image(image, mask):
    """
    Build the `(mask, masked_image)` pair consumed by the Paint by Example pipeline.

    Both outputs are 4-dimensional ``batch x channels x height x width`` tensors. The mask is inverted
    (``1 - mask``) and then binarized (values ``>= 0.5`` become ``1``, the rest ``0``); ``masked_image`` is the
    image multiplied by that inverted, binarized mask. The image is returned as ``torch.float32``.

    Args:
        image (Union[PIL.Image, List[PIL.Image], torch.Tensor]): The image to inpaint. Tensor inputs are
            ``3 x height x width`` or ``batch x 3 x height x width`` and must already lie in ``[-1, 1]``. PIL inputs
            are converted to RGB and scaled to ``[-1, 1]``.
        mask: The regions to inpaint. Tensor inputs are ``height x width``, ``1 x height x width``,
            ``batch x height x width`` or ``batch x 1 x height x width`` and must lie in ``[0, 1]``. PIL inputs are
            converted to single-channel and scaled to ``[0, 1]``.

    Raises:
        TypeError: exactly one of ``image`` / ``mask`` is a ``torch.Tensor``.
        ValueError: a tensor ``image`` is outside ``[-1, 1]`` or a tensor ``mask`` is outside ``[0, 1]``.
        AssertionError: tensor inputs have inconsistent shapes (rank, spatial size, batch size or mask channels).

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image).
    """
    image_is_tensor = isinstance(image, torch.Tensor)
    mask_is_tensor = isinstance(mask, torch.Tensor)

    # Mixed tensor / non-tensor inputs are rejected before any conversion happens.
    if image_is_tensor and not mask_is_tensor:
        raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
    if mask_is_tensor and not image_is_tensor:
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")

    if image_is_tensor:
        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        if mask.ndim == 2:
            # Batch and add channel dim for single mask
            mask = mask.unsqueeze(0).unsqueeze(0)
        elif mask.ndim == 3:
            # Leading dim equal to the image batch size -> treat as batched mask and add the channel dim;
            # otherwise treat it as a single mask and add the batch dim.
            insert_at = 1 if mask.shape[0] == image.shape[0] else 0
            mask = mask.unsqueeze(insert_at)

        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
        assert mask.shape[1] == 1, "Mask image must have a single channel"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # paint-by-example inverses the mask
        mask = 1 - mask

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    else:
        # PIL path: a single image / mask is wrapped into a one-element list first.
        if isinstance(image, PIL.Image.Image):
            image = [image]
        rgb_batch = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
        rgb_batch = rgb_batch.transpose(0, 3, 1, 2)
        # uint8 [0, 255] -> float32 [-1, 1]
        image = torch.from_numpy(rgb_batch).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, PIL.Image.Image):
            mask = [mask]
        mask_batch = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
        mask_batch = mask_batch.astype(np.float32) / 255.0

        # paint-by-example inverses the mask
        mask_batch = 1 - mask_batch
        mask_batch[mask_batch < 0.5] = 0
        mask_batch[mask_batch >= 0.5] = 1
        mask = torch.from_numpy(mask_batch)

    masked_image = image * mask

    return mask, masked_image
# NOTE(review): the definitions below are members of a pipeline class whose `class` header lies before this
# chunk of the dump; they are shown here without the class-level indentation.
_optional_components = ["safety_checker"]

def __init__(
    self,
    vae: AutoencoderKL,
    image_encoder: PaintByExampleImageEncoder,
    unet: UNet2DConditionModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = False,
):
    """
    Register the pipeline components and derive the VAE-dependent helpers.

    Args:
        vae: Autoencoder used to encode the masked image and decode the final latents.
        image_encoder: Encoder that turns the example image into conditioning embeddings.
        unet: Denoising network called once per timestep.
        scheduler: Scheduler that provides the timesteps and the `step` update.
        safety_checker: Optional checker run on decoded images (may be `None`).
        feature_extractor: Image processor used for the example image and for the safety checker input.
        requires_safety_checker: Stored in the pipeline config via `register_to_config`.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # spatial down-sampling factor between pixel space and latent space
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet, vae,
    image_encoder and (if present) safety checker are wrapped with accelerate's `cpu_offload`, with
    `cuda:{gpu_id}` as their execution device.

    Raises:
        ImportError: If `accelerate` is not available.
    """
    if is_accelerate_available():
        from accelerate import cpu_offload
    else:
        raise ImportError("Please install accelerate via `pip install accelerate`")

    device = torch.device(f"cuda:{gpu_id}")

    for cpu_offloaded_model in [self.unet, self.vae, self.image_encoder]:
        cpu_offload(cpu_offloaded_model, execution_device=device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    if not hasattr(self.unet, "_hf_hook"):
        return self.device
    for module in self.unet.modules():
        if (
            hasattr(module, "_hf_hook")
            and hasattr(module._hf_hook, "execution_device")
            and module._hf_hook.execution_device is not None
        ):
            return torch.device(module._hf_hook.execution_device)
    return self.device

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]

    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
    if accepts_generator:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs
def check_inputs(self, image, height, width, callback_steps):
    if (
        not isinstance(image, torch.Tensor)
        and not isinstance(image, PIL.Image.Image)
        and not isinstance(image, list)
    ):
        raise ValueError(
            "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
            f" {type(image)}"
        )

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if (callback_steps is None) or (
        callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
    ):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    # resize the mask to latents shape as we concatenate the mask to the latents
    # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
    # and half precision
    mask = torch.nn.functional.interpolate(
        mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
    )
    mask = mask.to(device=device, dtype=dtype)

    masked_image = masked_image.to(device=device, dtype=dtype)
    masked_image_latents = self._encode_vae_image(masked_image, generator=generator)

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if mask.shape[0] < batch_size:
        if not batch_size % mask.shape[0] == 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
    if masked_image_latents.shape[0] < batch_size:
        if not batch_size % masked_image_latents.shape[0] == 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

    mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
    masked_image_latents = (
        torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
    )

    # aligning device to prevent device errors when concating it with the latent model input
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
    return mask, masked_image_latents

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    if isinstance(generator, list):
        image_latents = [
            self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
            for i in range(image.shape[0])
        ]
        image_latents = torch.cat(image_latents, dim=0)
    else:
        image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)

    image_latents = self.vae.config.scaling_factor * image_latents

    return image_latents

def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
    """
    Encode the example image into the conditioning embeddings passed to the unet.

    Non-tensor inputs are first converted to pixel values with `self.feature_extractor`. The embeddings are
    repeated `num_images_per_prompt` times; with classifier-free guidance the unconditional vector returned by
    the image encoder is repeated likewise and concatenated in front of the conditional embeddings.
    """
    dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    image_embeddings, negative_prompt_embeds = self.image_encoder(image, return_uncond_vector=True)

    # duplicate image embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = image_embeddings.shape
    image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
    image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, image_embeddings.shape[0], 1)
        negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, 1, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and image embeddings into a single batch
        # to avoid doing two forward passes
        image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

    return image_embeddings

@torch.no_grad()
def __call__(
    self,
    example_image: Union[torch.FloatTensor, PIL.Image.Image],
    image: Union[torch.FloatTensor, PIL.Image.Image],
    mask_image: Union[torch.FloatTensor, PIL.Image.Image],
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        example_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
            The exemplar image to guide the image generation.
        image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
            be masked out with `mask_image` and repainted according to `example_image`.
        mask_image (`torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. NOTE(review): the value passed here is overwritten with
            the spatial size of the preprocessed masked image before it is used.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. NOTE(review): overwritten in the same way as `height`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the conditioning
            (here the embedding of `example_image`), usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. NOTE(review): this argument is accepted but
            never read in the body of this method.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different conditioning. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Passing `"latent"`
            skips VAE decoding and the safety checker.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """
    # 1. Define call parameters
    if isinstance(image, PIL.Image.Image):
        batch_size = 1
    elif isinstance(image, list):
        batch_size = len(image)
    else:
        batch_size = image.shape[0]
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Preprocess mask and image
    mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
    # the working resolution is taken from the preprocessed image, not from the `height`/`width` arguments
    height, width = masked_image.shape[-2:]

    # 3. Check inputs
    self.check_inputs(example_image, height, width, callback_steps)

    # 4. Encode input image
    image_embeddings = self._encode_image(
        example_image, device, num_images_per_prompt, do_classifier_free_guidance
    )

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 6. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 7. Prepare mask latent variables
    mask, masked_image_latents = self.prepare_mask_latents(
        mask,
        masked_image,
        batch_size * num_images_per_prompt,
        height,
        width,
        image_embeddings.dtype,
        device,
        generator,
        do_classifier_free_guidance,
    )

    # 8. Check that sizes of mask, masked image and latents match
    num_channels_mask = mask.shape[1]
    num_channels_masked_image = masked_image_latents.shape[1]
    if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
        raise ValueError(
            f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
            f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
            f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
            " `pipeline.unet` or your `mask_image` or `image` input."
        )

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # concat latents, mask, masked_image_latents in the channel dimension
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            latent_model_input = torch.cat([latent_model_input, masked_image_latents, mask], dim=1)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are not denormalized
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


================================================
FILE: 6DoF/diffusers/pipelines/pipeline_flax_utils.py
================================================
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib import inspect import os from typing import Any, Dict, List, Optional, Union import flax import numpy as np import PIL from flax.core.frozen_dict import FrozenDict from huggingface_hub import snapshot_download from PIL import Image from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging if is_transformers_available(): from transformers import FlaxPreTrainedModel INDEX_FILE = "diffusion_flax_model.bin" logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "FlaxModelMixin": ["save_pretrained", "from_pretrained"], "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], }, "transformers": { "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], "ProcessorMixin": ["save_pretrained", "from_pretrained"], "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], }, } ALL_IMPORTABLE_CLASSES = {} for library in LOADABLE_CLASSES: ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) def import_flax_or_no_model(module, class_name): try: # 1. First make sure that if a Flax object is present, import this one class_obj = getattr(module, "Flax" + class_name) except AttributeError: # 2. 
If this doesn't work, it's not a model and we don't append "Flax" class_obj = getattr(module, class_name) except AttributeError: raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") return class_obj @flax.struct.dataclass class FlaxImagePipelineOutput(BaseOutput): """ Output class for image pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, num_channels)`. """ images: Union[List[PIL.Image.Image], np.ndarray] class FlaxDiffusionPipeline(ConfigMixin): r""" Base class for all models. [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: - enabling/disabling the progress bar for the denoising iteration Class attributes: - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all components of the diffusion pipeline. """ config_name = "model_index.json" def register_modules(self, **kwargs): # import it here to avoid circular import from diffusers import pipelines for name, module in kwargs.items(): if module is None: register_dict = {name: (None, None)} else: # retrieve library library = module.__module__.split(".")[0] # check if the module is a pipeline module pipeline_dir = module.__module__.split(".")[-2] path = module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) # if library is not in LOADABLE_CLASSES, then it is a custom module. # Or if it's a pipeline module, then the module is inside the pipeline # folder so we set the library to module name. 
if library not in LOADABLE_CLASSES or is_pipeline_module: library = pipeline_dir # retrieve class_name class_name = module.__class__.__name__ register_dict = {name: (library, class_name)} # save model index config self.register_to_config(**register_dict) # set models setattr(self, name, module) def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): # TODO: handle inference_state """ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class method. Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. """ self.save_config(save_directory) model_index_dict = dict(self.config) model_index_dict.pop("_class_name") model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module", None) for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) if sub_model is None: # edge case for saving a pipeline with safety_checker=None continue model_cls = sub_model.__class__ save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): library = importlib.import_module(library_name) for base_class, save_load_methods in library_classes.items(): class_candidate = getattr(library, base_class, None) if class_candidate is not None and issubclass(model_cls, class_candidate): # if we found a suitable base class in LOADABLE_CLASSES then grab its save method save_method_name = save_load_methods[0] break if save_method_name is not None: break save_method = getattr(sub_model, save_method_name) expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) if expects_params: 
save_method( os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] ) else: save_method(os.path.join(save_directory, pipeline_component_name)) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a Flax diffusion pipeline from pre-trained pipeline weights. The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those weights are discarded. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like `CompVis/ldm-text2im-large-256`. - A path to a *directory* containing pipeline weights saved using [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. dtype (`str` or `jnp.dtype`, *optional*): Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype will be automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. 
proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. mirror (`str`, *optional*): Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. specify the folder name here. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the specific pipeline class. The overwritten components are then directly passed to the pipelines `__init__` method. See example below for more information. 
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"` Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment. Examples: ```py >>> from diffusers import FlaxDiffusionPipeline >>> # Download pipeline from huggingface.co and cache. >>> # Requires to be logged in to Hugging Face hub, >>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline, params = FlaxDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", ... revision="bf16", ... dtype=jnp.bfloat16, ... ) >>> # Download pipeline, but use a different scheduler >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" >>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) >>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained( ... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp ... ) >>> dpm_params["scheduler"] = dpmpp_state ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_pt = kwargs.pop("from_pt", False) use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False) dtype = kwargs.pop("dtype", None) # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, ) # make sure we only download sub-folders and `diffusers` filenames folder_names = [k for k in config_dict.keys() if not k.startswith("_")] allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] # make sure we don't download PyTorch weights, unless when using from_pt ignore_patterns = "*.bin" if not from_pt else [] if cls != FlaxDiffusionPipeline: requested_pipeline_class = cls.__name__ else: requested_pipeline_class = config_dict.get("_class_name", cls.__name__) requested_pipeline_class = ( requested_pipeline_class if requested_pipeline_class.startswith("Flax") else "Flax" + requested_pipeline_class ) user_agent = {"pipeline_class": requested_pipeline_class} user_agent = http_user_agent(user_agent) # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, user_agent=user_agent, ) else: cached_folder = pretrained_model_name_or_path config_dict = cls.load_config(cached_folder) # 2. 
Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != FlaxDiffusionPipeline: pipeline_class = cls else: diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) class_name = ( config_dict["_class_name"] if config_dict["_class_name"].startswith("Flax") else "Flax" + config_dict["_class_name"] ) pipeline_class = getattr(diffusers_module, class_name) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} # inference_params params = {} # import it here to avoid circular import from diffusers import pipelines # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): if class_name is None: # edge case for when the pipeline was saved with safety_checker=None init_kwargs[name] = None continue is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: # 1. 
check that passed_class_obj has correct parent class if not is_pipeline_module: library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} expected_class_obj = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate if not issubclass(passed_class_obj[name].__class__, expected_class_obj): raise ValueError( f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) elif passed_class_obj[name] is None: logger.warning( f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" f" that this might lead to problems when using {pipeline_class} and is not recommended." ) sub_model_should_be_defined = False else: logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) # set passed class object loaded_sub_model = passed_class_obj[name] elif is_pipeline_module: pipeline_module = getattr(pipelines, library_name) class_obj = import_flax_or_no_model(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. 
library = importlib.import_module(library_name) class_obj = import_flax_or_no_model(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): load_method_name = importable_classes[class_name][1] load_method = getattr(class_obj, load_method_name) # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): loadable_folder = os.path.join(cached_folder, name) else: loaded_sub_model = cached_folder if issubclass(class_obj, FlaxModelMixin): loaded_sub_model, loaded_params = load_method( loadable_folder, from_pt=from_pt, use_memory_efficient_attention=use_memory_efficient_attention, dtype=dtype, ) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): if from_pt: # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) loaded_params = loaded_sub_model.params del loaded_sub_model._params else: loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False) params[name] = loaded_params elif issubclass(class_obj, FlaxSchedulerMixin): loaded_sub_model, scheduler_state = load_method(loadable_folder) params[name] = scheduler_state else: loaded_sub_model = load_method(loadable_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) # 4. 
Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) if len(missing_modules) > 0 and missing_modules <= set(passed_modules): for module in missing_modules: init_kwargs[module] = passed_class_obj.get(module, None) elif len(missing_modules) > 0: passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) model = pipeline_class(**init_kwargs, dtype=dtype) return model, params @staticmethod def _get_signature_keys(obj): parameters = inspect.signature(obj.__init__).parameters required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) expected_modules = set(required_parameters.keys()) - {"self"} return expected_modules, optional_parameters @property def components(self) -> Dict[str, Any]: r""" The `self.components` property can be useful to run different pipelines with the same weights and configurations to not have to re-allocate memory. Examples: ```py >>> from diffusers import ( ... FlaxStableDiffusionPipeline, ... FlaxStableDiffusionImg2ImgPipeline, ... ) >>> text2img = FlaxStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jnp.bfloat16 ... ) >>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components) ``` Returns: A dictionary containing all the modules needed to initialize the pipeline. 
""" expected_modules, optional_parameters = self._get_signature_keys(self) components = { k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters } if set(components.keys()) != expected_modules: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" f" {expected_modules} to be defined, but {components} are defined." ) return components @staticmethod def numpy_to_pil(images): """ Convert a numpy image or a batch of images to a PIL image. """ if images.ndim == 3: images = images[None, ...] images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images # TODO: make it compatible with jax.lax def progress_bar(self, iterable): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): raise ValueError( f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) return tqdm(iterable, **self._progress_bar_config) def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs ================================================ FILE: 6DoF/diffusers/pipelines/pipeline_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import fnmatch import importlib import inspect import os import re import sys import warnings from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch from huggingface_hub import hf_hub_download, model_info, snapshot_download from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm import diffusers from .. import __version__ from ..configuration_utils import ConfigMixin from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, DEPRECATED_REVISION_ARGS, DIFFUSERS_CACHE, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, deprecate, get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, is_compiled_module, is_safetensors_available, is_torch_version, is_transformers_available, logging, numpy_to_pil, ) if is_transformers_available(): import transformers from transformers import PreTrainedModel from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME if is_accelerate_available(): import accelerate INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" DUMMY_MODULES_FOLDER = "diffusers.utils" 
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Checking for safetensors compatibility:
    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
      files to know which safetensors files are needed.
    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Args:
        filenames: Iterable of "/"-separated file names (for example `unet/diffusion_pytorch_model.bin`).
        variant: Accepted for interface compatibility; this check does not use it.
        passed_components: Optional collection of component folder names. Files of the form
            `<component>/<file>` whose component is listed here are ignored.

    Returns:
        `True` if every considered ".bin" file has its expected ".safetensors" counterpart in `filenames`
        (trivially `True` when there is no ".bin" file), otherwise `False` after logging the first missing name.
    """
    # The component check below splits on "/", so the names are treated as "/"-separated throughout.
    # `posixpath` keeps the rebuilt name comparable with the entries of `filenames` on every host OS,
    # whereas `os.path.join` would insert the platform separator.
    import posixpath

    ignored_components = passed_components or []
    pt_filenames = []
    sf_filenames = set()

    for filename in filenames:
        parts = filename.split("/")
        if len(parts) == 2 and parts[0] in ignored_components:
            # the caller supplies this component itself, so its weight files are irrelevant
            continue

        extension = posixpath.splitext(filename)[1]
        if extension == ".bin":
            pt_filenames.append(filename)
        elif extension == ".safetensors":
            sf_filenames.add(filename)

    for pt_filename in pt_filenames:
        # 'foo/bar/baz.bin' -> folder = 'foo/bar', stem = 'baz'
        folder, basename = posixpath.split(pt_filename)
        stem = posixpath.splitext(basename)[0]

        # transformers checkpoints are named "pytorch_model*" in pickle form and "model*" in safetensors form
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = f"{posixpath.join(folder, stem)}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
weight_suffixs = [w.split(".")[-1] for w in weight_names] # -00001-of-00002 transformers_index_format = r"\d{5}-of-\d{5}" if variant is not None: # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors` variant_file_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$" ) # `text_encoder/pytorch_model.bin.index.fp16.json` variant_index_re = re.compile( rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$" ) # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors` non_variant_file_re = re.compile( rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$" ) # `text_encoder/pytorch_model.bin.index.json` non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json") if variant is not None: variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None} variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None} variant_filenames = variant_weights | variant_indexes else: variant_filenames = set() non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None} non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None} non_variant_filenames = non_variant_weights | non_variant_indexes # all variant filenames will be used by default usable_filenames = set(variant_filenames) def convert_to_variant(filename): if "index" in filename: variant_filename = filename.replace("index", f"index.{variant}") elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None: variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}" else: variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}" return 
variant_filename for f in non_variant_filenames: variant_filename = convert_to_variant(f) if variant_filename not in usable_filenames: usable_filenames.add(f) return usable_filenames, variant_filenames def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames): info = model_info( pretrained_model_name_or_path, use_auth_token=use_auth_token, revision=None, ) filenames = {sibling.rfilename for sibling in info.siblings} comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] if set(comp_model_filenames) == set(model_filenames): warnings.warn( f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", FutureWarning, ) else: warnings.warn( f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. 
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """Check that the user-passed component `name` has a type compatible with its configured class.

    For a pipeline module no check is possible, so only a warning is logged. Otherwise `library_name` is
    imported, the configured class `class_name` is looked up in it, and the passed object
    `passed_class_obj[name]` must be an instance of a subclass of the loadable base class that the
    configured class derives from.

    Raises:
        ValueError: if the passed object's class is not a subclass of the expected base class.
    """
    if is_pipeline_module:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    # the incoming `library` argument is replaced by a fresh import of `library_name`
    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)
    candidates = [getattr(library, base_name, None) for base_name in importable_classes.keys()]

    # keep the LAST candidate that the configured class derives from
    expected_class_obj = None
    for candidate in candidates:
        if candidate is not None and issubclass(class_obj, candidate):
            expected_class_obj = candidate

    # Dynamo wraps the original model in a private class.
    # I didn't find a public API to get the original class.
    sub_model = passed_class_obj[name]
    if is_compiled_module(sub_model):
        model_cls = sub_model._orig_mod.__class__
    else:
        model_cls = sub_model.__class__

    if issubclass(model_cls, expected_class_obj):
        return

    raise ValueError(
        f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
    )
def _get_pipeline_class(class_obj, config, custom_pipeline=None, cache_dir=None, revision=None):
    """Return the pipeline class that should be instantiated.

    - With `custom_pipeline` set, the class is fetched through `get_class_from_dynamic_module`. A value ending
      in ".py" is split into its absolute parent folder and its file name; otherwise the default
      `CUSTOM_PIPELINE_FILE_NAME` is used as the module file.
    - Without it, any `class_obj` other than `DiffusionPipeline` is returned unchanged.
    - For `DiffusionPipeline` itself, the class named by `config["_class_name"]` is looked up on the top-level
      package of `class_obj`'s module.
    """
    if custom_pipeline is None:
        if class_obj != DiffusionPipeline:
            return class_obj

        root_package_name = class_obj.__module__.split(".")[0]
        return getattr(importlib.import_module(root_package_name), config["_class_name"])

    module_file = CUSTOM_PIPELINE_FILE_NAME
    if custom_pipeline.endswith(".py"):
        # decompose into folder & file
        script_path = Path(custom_pipeline)
        module_file = script_path.name
        custom_pipeline = script_path.parent.absolute()

    return get_class_from_dynamic_module(
        custom_pipeline, module_file=module_file, cache_dir=cache_dir, revision=revision
    )
none_module = class_obj.__module__ is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( TRANSFORMERS_DUMMY_MODULES_FOLDER ) if is_dummy_path and "dummy" in none_module: # call class_obj for nice error message of missing requirements class_obj() raise ValueError( f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have" f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." ) load_method = getattr(class_obj, load_method_name) # add kwargs to loading method loading_kwargs = {} if issubclass(class_obj, torch.nn.Module): loading_kwargs["torch_dtype"] = torch_dtype if issubclass(class_obj, diffusers.OnnxRuntimeModel): loading_kwargs["provider"] = provider loading_kwargs["sess_options"] = sess_options is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin) if is_transformers_available(): transformers_version = version.parse(version.parse(transformers.__version__).base_version) else: transformers_version = "N/A" is_transformers_model = ( is_transformers_available() and issubclass(class_obj, PreTrainedModel) and transformers_version >= version.parse("4.20.0") ) # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. # This makes sure that the weights won't be initialized which significantly speeds up loading. 
if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map loading_kwargs["max_memory"] = max_memory loading_kwargs["offload_folder"] = offload_folder loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["variant"] = model_variants.pop(name, None) if from_flax: loading_kwargs["from_flax"] = True # the following can be deleted once the minimum required `transformers` version # is higher than 4.27 if ( is_transformers_model and loading_kwargs["variant"] is not None and transformers_version < version.parse("4.27.0") ): raise ImportError( f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0" ) elif is_transformers_model and loading_kwargs["variant"] is None: loading_kwargs.pop("variant") # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage` if not (from_flax and is_transformers_model): loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage else: loading_kwargs["low_cpu_mem_usage"] = False # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) else: # else load from the root directory loaded_sub_model = load_method(cached_folder, **loading_kwargs) return loaded_sub_model class DiffusionPipeline(ConfigMixin): r""" Base class for all pipelines. [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and provides methods for loading, downloading and saving models. It also includes methods to: - move all PyTorch modules to the device of your choice - enabling/disabling the progress bar for the denoising iteration Class attributes: - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the diffusion pipeline's components. 
- **_optional_components** (List[`str`]) -- List of all optional components that don't have to be passed to the pipeline to function (should be overridden by subclasses). """ config_name = "model_index.json" _optional_components = [] def register_modules(self, **kwargs): # import it here to avoid circular import from diffusers import pipelines for name, module in kwargs.items(): # retrieve library if module is None: register_dict = {name: (None, None)} else: # register the config from the original module, not the dynamo compiled one if is_compiled_module(module): not_compiled_module = module._orig_mod else: not_compiled_module = module library = not_compiled_module.__module__.split(".")[0] # check if the module is a pipeline module module_path_items = not_compiled_module.__module__.split(".") pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None path = not_compiled_module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) # if library is not in LOADABLE_CLASSES, then it is a custom module. # Or if it's a pipeline module, then the module is inside the pipeline # folder so we set the library to module name. 
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    safe_serialization: bool = False,
    variant: Optional[str] = None,
):
    """
    Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
    class implements both a save and loading method. The pipeline is easily reloaded using the
    [`~DiffusionPipeline.from_pretrained`] class method.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save a pipeline to. Each component is written to the sub-folder
            `save_directory/<component name>`.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Only
            forwarded to save methods that accept a `safe_serialization` argument.
        variant (`str`, *optional*):
            If specified, weights are saved in the format `pytorch_model.<variant>.bin`. Only forwarded to save
            methods that accept a `variant` argument.
    """
    model_index_dict = dict(self.config)
    # bookkeeping entries of the config are not components
    for private_key in ("_class_name", "_diffusers_version", "_module"):
        model_index_dict.pop(private_key, None)

    expected_modules, _ = self._get_signature_keys(self)

    def is_saveable_module(name, value):
        if name not in expected_modules:
            return False
        # optional components registered as (None, ...) were never set
        if name in self._optional_components and value[0] is None:
            return False
        return True

    model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}

    for pipeline_component_name in model_index_dict:
        sub_model = getattr(self, pipeline_component_name)
        model_cls = sub_model.__class__

        # Dynamo wraps the original model in a private class.
        # I didn't find a public API to get the original class.
        if is_compiled_module(sub_model):
            sub_model = sub_model._orig_mod
            model_cls = sub_model.__class__

        save_method_name = None
        # search for the model's base class in LOADABLE_CLASSES
        for library_name, library_classes in LOADABLE_CLASSES.items():
            if library_name not in sys.modules:
                logger.info(
                    f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
                )
                # Skip this library: without the `continue` the lookup below would run against a
                # `library` left over from a previous iteration (or an unbound name on the first one).
                continue

            library = importlib.import_module(library_name)
            for base_class, save_load_methods in library_classes.items():
                class_candidate = getattr(library, base_class, None)
                if class_candidate is not None and issubclass(model_cls, class_candidate):
                    # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                    save_method_name = save_load_methods[0]
                    break
            if save_method_name is not None:
                break

        if save_method_name is None:
            logger.warning(
                f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
            )
            # make sure that unsaveable components are not tried to be loaded afterward
            self.register_to_config(**{pipeline_component_name: (None, None)})
            continue

        save_method = getattr(sub_model, save_method_name)

        # Call the save method with the arguments safe_serialization / variant only if it supports them
        save_method_parameters = inspect.signature(save_method).parameters
        save_kwargs = {}
        if "safe_serialization" in save_method_parameters:
            save_kwargs["safe_serialization"] = safe_serialization
        if "variant" in save_method_parameters:
            save_kwargs["variant"] = variant

        save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

    # finally save the config
    self.save_config(save_directory)
def module_is_sequentially_offloaded(module): if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"): return False return hasattr(module, "_hf_hook") and not isinstance( module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook) ) def module_is_offloaded(module): if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"): return False return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda": raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." ) # Display a warning in this case (the operation succeeds but the benefits are lost) pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) if pipeline_is_offloaded and torch.device(torch_device).type == "cuda": logger.warning( f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." 
) module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit if is_loaded_in_8bit and torch_dtype is not None: logger.warning( f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." ) if is_loaded_in_8bit and torch_device is not None: logger.warning( f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." ) else: module.to(torch_device, torch_dtype) if ( module.dtype == torch.float16 and str(torch_device) in ["cpu"] and not silence_dtype_warnings and not is_offloaded ): logger.warning( "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It" " is not recommended to move them to `cpu` as running them will fail. Please make" " sure to use an accelerator to run the pipeline in inference, due to the lack of" " support for`float16` operations on this device in PyTorch. Please, remove the" " `torch_dtype=torch.float16` argument, or use another device for inference." ) return self @property def device(self) -> torch.device: r""" Returns: `torch.device`: The torch device on which the pipeline is located. """ module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] for module in modules: return module.device return torch.device("cpu") @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights. 
The pipeline is set in evaluation mode (`model.eval()`) by default. If you get the error message below, you need to finetune the weights for your downstream task: ``` Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. ``` Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. custom_pipeline (`str`, *optional*): 🧪 This is an experimental feature and may change in the future. Can be either: - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines the custom pipeline. - A string, the *file name* of a community pipeline hosted on GitHub under [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the current main branch of GitHub. - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. 
The directory must contain a file called `pipeline.py` that defines the custom pipeline. For more information on how to load and create custom pipelines, please have a look at [Loading and Adding Custom Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. custom_revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
It can be a branch name, a tag name, or a commit id similar to `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): A map that specifies where each submodule should go. It doesn’t need to be defined for each parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): A dictionary device identifier for the maximum memory. Will default to the maximum memory available for each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): The path to offload weights if device_map contains the value `"disk"`. offload_state_dict (`bool`, *optional*): If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading only loading the pretrained weights and not initializing the weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. 
If you are using an older version of PyTorch, setting this argument to `True` will raise an error. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. See example below for more information. variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `huggingface-cli login`. Examples: ```py >>> from diffusers import DiffusionPipeline >>> # Download pipeline from huggingface.co and cache. 
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") >>> # Download pipeline that requires an authorization token >>> # For more information on access tokens, please refer to this section >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # Use a different scheduler >>> from diffusers import LMSDiscreteScheduler >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) >>> pipeline.scheduler = scheduler ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_flax = kwargs.pop("from_flax", False) torch_dtype = kwargs.pop("torch_dtype", None) custom_pipeline = kwargs.pop("custom_pipeline", None) custom_revision = kwargs.pop("custom_revision", None) provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) max_memory = kwargs.pop("max_memory", None) offload_folder = kwargs.pop("offload_folder", None) offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): cached_folder = cls.download( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, force_download=force_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, from_flax=from_flax, use_safetensors=use_safetensors, custom_pipeline=custom_pipeline, custom_revision=custom_revision, variant=variant, **kwargs, ) else: cached_folder = pretrained_model_name_or_path config_dict = cls.load_config(cached_folder) # pop out "_ignore_files" as it is only needed for download config_dict.pop("_ignore_files", None) # 2. Define which model components should load variants # We retrieve the information by matching whether variant # model checkpoints exist in the subfolders model_variants = {} if variant is not None: for folder in os.listdir(cached_folder): folder_path = os.path.join(cached_folder, folder) is_folder = os.path.isdir(folder_path) and folder in config_dict variant_exists = is_folder and any( p.split(".")[1].startswith(variant) for p in os.listdir(folder_path) ) if variant_exists: model_variants[folder] = variant # 3. 
Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it pipeline_class = _get_pipeline_class( cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision ) # DEPRECATED: To be removed in 1.0.0 if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse( version.parse(config_dict["_diffusers_version"]).base_version ) <= version.parse("0.5.1"): from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy pipeline_class = StableDiffusionInpaintPipelineLegacy deprecation_message = ( "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the" f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For" " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting" " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your" f" checkpoint {pretrained_model_name_or_path} to the format of" " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain" " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0." ) deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False) # 4. 
Define expected modules given pipeline signature # and define non-None initialized modules (=`init_kwargs`) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) # define init kwargs init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} init_kwargs = {**init_kwargs, **passed_pipe_kwargs} # remove `null` components def load_module(name, value): if value[0] is None: return False if name in passed_class_obj and passed_class_obj[name] is None: return False return True init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( "The safety checker cannot be automatically loaded when loading weights `from_flax`." " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker" " separately if you need it." ) # 5. Throw nice warnings / errors for fast accelerate loading if len(unused_kwargs) > 0: logger.warning( f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." ) if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False logger.warning( "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" " install accelerate\n```\n." 
) if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" " `device_map=None`." ) if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" " `low_cpu_mem_usage=False`." ) if low_cpu_mem_usage is False and device_map is not None: raise ValueError( f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and" " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) # import it here to avoid circular import from diffusers import pipelines # 6. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names if class_name.startswith("Flax"): class_name = class_name[4:] # 6.2 Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) importable_classes = ALL_IMPORTABLE_CLASSES loaded_sub_model = None # 6.3 Use passed sub model or load class_name from library_name if name in passed_class_obj: # if the model is in a pipeline module, then we load it from the pipeline # check that passed_class_obj has correct parent class maybe_raise_or_warn( library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module ) loaded_sub_model = passed_class_obj[name] else: # load sub model loaded_sub_model = load_sub_model( library_name=library_name, class_name=class_name, importable_classes=importable_classes, pipelines=pipelines, is_pipeline_module=is_pipeline_module, pipeline_class=pipeline_class, torch_dtype=torch_dtype, provider=provider, sess_options=sess_options, device_map=device_map, max_memory=max_memory, offload_folder=offload_folder, offload_state_dict=offload_state_dict, 
model_variants=model_variants, name=name, from_flax=from_flax, variant=variant, low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, ) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) # 7. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) passed_modules = list(passed_class_obj.keys()) optional_modules = pipeline_class._optional_components if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): for module in missing_modules: init_kwargs[module] = passed_class_obj.get(module, None) elif len(missing_modules) > 0: passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) # 8. Instantiate the pipeline model = pipeline_class(**init_kwargs) return model @classmethod def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: r""" Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights. Parameters: pretrained_model_name (`str` or `os.PathLike`, *optional*): A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. custom_pipeline (`str`, *optional*): Can be either: - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines the custom pipeline. - A string, the *file name* of a community pipeline hosted on GitHub under [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the current `main` branch of GitHub. 
- A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory must contain a file called `pipeline.py` that defines the custom pipeline. 🧪 This is an experimental feature and may change in the future. For more information on how to load and create custom pipelines, take a look at [How to contribute a community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to resume downloading the model weights and configuration files. If set to `False`, any incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only (`bool`, *optional*, defaults to `False`): Whether to only load local model weights and configuration files or not. If set to `True`, the model won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. custom_revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
It can be a branch name, a tag name, or a commit id similar to `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub. mirror (`str`, *optional*): Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. Returns: `os.PathLike`: A path to the downloaded pipeline. To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `huggingface-cli login`. """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) from_flax = kwargs.pop("from_flax", False) custom_pipeline = kwargs.pop("custom_pipeline", None) custom_revision = kwargs.pop("custom_revision", None) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. 
Please install safetensors with `pip install safetensors" ) allow_pickle = False if use_safetensors is None: use_safetensors = is_safetensors_available() allow_pickle = True pipeline_is_cached = False allow_patterns = None ignore_patterns = None if not local_files_only: try: info = model_info( pretrained_model_name, use_auth_token=use_auth_token, revision=revision, ) except HTTPError as e: logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") local_files_only = True if not local_files_only: config_file = hf_hub_download( pretrained_model_name, cls.config_name, cache_dir=cache_dir, revision=revision, proxies=proxies, force_download=force_download, resume_download=resume_download, use_auth_token=use_auth_token, ) config_dict = cls._dict_from_json_file(config_file) ignore_filenames = config_dict.pop("_ignore_files", []) # retrieve all folder_names that contain relevant files folder_names = [k for k, v in config_dict.items() if isinstance(v, list)] filenames = {sibling.rfilename for sibling in info.siblings} model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) if len(variant_filenames) == 0 and variant is not None: deprecation_message = ( f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" "if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant" "modeling files is deprecated." 
) deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False) # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) variant_filenames = set(variant_filenames) - set(ignore_filenames) # if the whole pipeline is cached we don't have to ping the Hub if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version ) >= version.parse("0.20.0"): warn_deprecated_model_variant( pretrained_model_name, use_auth_token, variant, revision, model_filenames ) model_folder_names = {os.path.split(f)[0] for f in model_filenames} # all filenames compatible with variant will be added allow_patterns = list(model_filenames) # allow all patterns from non-model folders # this enables downloading schedulers, tokenizers, ... allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] allow_patterns += [ SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name, CUSTOM_PIPELINE_FILE_NAME, ] # retrieve passed components that should not be downloaded pipeline_class = _get_pipeline_class( cls, config_dict, custom_pipeline=custom_pipeline, cache_dir=cache_dir, revision=custom_revision ) expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] if ( use_safetensors and not allow_pickle and not is_safetensors_compatible( model_filenames, variant=variant, passed_components=passed_components ) ): raise EnvironmentError( f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})" ) if from_flax: ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] elif use_safetensors and is_safetensors_compatible( model_filenames, variant=variant, passed_components=passed_components ): ignore_patterns = ["*.bin", "*.msgpack"] safetensors_variant_filenames 
= {f for f in variant_filenames if f.endswith(".safetensors")} safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} if ( len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames ): logger.warn( f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." ) else: ignore_patterns = ["*.safetensors", "*.msgpack"] bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: logger.warn( f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." 
@staticmethod
def _get_signature_keys(obj):
    r"""
    Classify the parameters of `obj.__init__` by whether they declare a default value.

    Args:
        obj: A class or an instance whose `__init__` signature is inspected.

    Returns:
        `tuple`: `(expected_modules, optional_parameters)`, two sets of parameter names. The first holds
        every parameter declared without a default (with `"self"` removed), the second every parameter
        declared with a default.
    """
    parameters = inspect.signature(obj.__init__).parameters
    # `inspect.Parameter.empty` is the public name of the "no default" marker.
    no_default = inspect.Parameter.empty

    # Compare by identity: `==` / `!=` would dispatch to the default value's own comparison operators,
    # which may raise or return a non-boolean result (array-like defaults compare element-wise).
    expected_modules = {name for name, param in parameters.items() if param.default is no_default}
    expected_modules.discard("self")
    optional_parameters = {name for name, param in parameters.items() if param.default is not no_default}

    return expected_modules, optional_parameters
def progress_bar(self, iterable=None, total=None):
    r"""
    Build a `tqdm` progress bar using the keyword arguments stored in `self._progress_bar_config`.

    Args:
        iterable (*optional*): Object to wrap; when given it is passed to `tqdm` positionally and `total` is
            ignored.
        total (*optional*): Passed to `tqdm` as `total=` when no `iterable` is given.

    Returns:
        The object returned by `tqdm`.

    Raises:
        ValueError: If `self._progress_bar_config` exists but is not a `dict`, or if neither `iterable` nor
            `total` is provided.
    """
    if hasattr(self, "_progress_bar_config"):
        if not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )
    else:
        # First use without `set_progress_bar_config`: start from an empty configuration.
        self._progress_bar_config = {}

    if iterable is None and total is None:
        raise ValueError("Either `total` or `iterable` has to be defined.")

    bar_options = self._progress_bar_config
    if iterable is not None:
        return tqdm(iterable, **bar_options)
    return tqdm(total=total, **bar_options)
def set_use_memory_efficient_attention_xformers(
    self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
    r"""
    Forward `(valid, attention_op)` to every sub-module of the pipeline components that exposes a
    `set_use_memory_efficient_attention_xformers` method.

    Component names are taken from `self._get_signature_keys(self)`; only attributes that are
    `torch.nn.Module` instances are visited. Each component is walked depth-first, parent before children.

    Args:
        valid (`bool`): Flag handed unchanged to each sub-module's method.
        attention_op (`Callable`, *optional*): Operator handed unchanged to each sub-module's method.
    """
    method_name = "set_use_memory_efficient_attention_xformers"

    def _propagate(node: torch.nn.Module):
        # Notify the node itself first, then descend into its direct children.
        if hasattr(node, method_name):
            node.set_use_memory_efficient_attention_xformers(valid, attention_op)
        for sub_node in node.children():
            _propagate(sub_node)

    component_names, _ = self._get_signature_keys(self)
    # Resolve all attributes up front; names without a matching attribute resolve to `None`.
    candidates = [getattr(self, component_name, None) for component_name in component_names]
    torch_components = [candidate for candidate in candidates if isinstance(candidate, torch.nn.Module)]

    for component in torch_components:
        _propagate(component)
def set_attention_slice(self, slice_size: Optional[Union[str, int]]):
    r"""
    Forward `slice_size` to every pipeline component that supports attention slicing.

    Component names are taken from `self._get_signature_keys(self)`. A component is used only if it is a
    `torch.nn.Module` that exposes a `set_attention_slice` method; all other attributes are skipped.

    Args:
        slice_size (`str` or `int`, *optional*):
            Value handed unchanged to each component's `set_attention_slice`. `enable_attention_slicing`
            forwards its own argument here (a string such as `"auto"`, an integer, or `None`), so the
            annotation accepts the same types.
    """
    module_names, _ = self._get_signature_keys(self)
    # Names without a matching attribute resolve to `None` and are filtered out below.
    candidates = [getattr(self, name, None) for name in module_names]
    sliceable = [m for m in candidates if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]

    for module in sliceable:
        module.set_attention_slice(slice_size)
@torch.no_grad()
def __call__(
    self,
    batch_size: int = 1,
    num_inference_steps: int = 50,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    **kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
    r"""
    Generate images by denoising random noise with `self.unet` and `self.scheduler`.

    Args:
        batch_size (`int`, *optional*, defaults to 1):
            The number of images to generate.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps, passed to `self.scheduler.set_timesteps`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Forwarded to `randn_tensor` when drawing the initial noise.
        output_type (`str`, *optional*, defaults to `"pil"`):
            When equal to `"pil"`, the result is converted with `self.numpy_to_pil`; any other value
            leaves it as a NumPy array.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return an [`ImagePipelineOutput`] instead of a plain tuple.
        kwargs:
            Accepted but not used by this method.

    Returns:
        [`ImagePipelineOutput`] or `tuple`: an [`ImagePipelineOutput`] if `return_dict` is `True`, otherwise
        a one-element `tuple` holding the generated images.
    """
    # For more information on the sampling method you can take a look at Algorithm 2 of
    # the official paper: https://arxiv.org/pdf/2202.09778.pdf

    # Start the loop from noise shaped (batch, in_channels, sample_size, sample_size).
    noise_shape = (
        batch_size,
        self.unet.config.in_channels,
        self.unet.config.sample_size,
        self.unet.config.sample_size,
    )
    sample = randn_tensor(noise_shape, generator=generator, device=self.device)

    self.scheduler.set_timesteps(num_inference_steps)
    for timestep in self.progress_bar(self.scheduler.timesteps):
        unet_output = self.unet(sample, timestep).sample
        sample = self.scheduler.step(unet_output, timestep, sample).prev_sample

    # Map x -> x / 2 + 0.5, clamp to [0, 1], then move axis 1 to the last position.
    sample = (sample / 2 + 0.5).clamp(0, 1)
    images = sample.cpu().permute(0, 2, 3, 1).numpy()
    if output_type == "pil":
        images = self.numpy_to_pil(images)

    if return_dict:
        return ImagePipelineOutput(images=images)
    return (images,)
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ...models import UNet2DModel
from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert the image input of the pipeline into a single `torch.Tensor`.

    A `FutureWarning` is emitted on every call because this helper is deprecated.

    Args:
        image: A tensor (returned unchanged), a single `PIL.Image.Image`, a list of PIL images, or a list of
            tensors.

    Returns:
        For PIL input, a float32 tensor built by resizing every image to the size of the first image rounded down
        to a multiple of 8, stacking along a new leading axis, dividing by 255, moving the last axis to position 1
        and mapping values with `2 * x - 1`. For a list of tensors, their concatenation along dim 0.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        # every image is resized to the (rounded) size of the first one so they can be stacked
        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        # the transpose needs a 4-D array, i.e. images that carry a channel axis (e.g. RGB)
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


def _preprocess_mask(mask: Union[List, PIL.Image.Image, torch.Tensor]):
    """Convert the mask input of the pipeline into a single `torch.Tensor`.

    Args:
        mask: A tensor (returned unchanged), a single `PIL.Image.Image`, a list of PIL images, or a list of
            tensors.

    Returns:
        For PIL input, a float32 tensor of zeros and ones: every mask is converted to mode "L", resized with
        nearest-neighbour resampling to the size of the first mask rounded down to a multiple of 32, stacked along
        a new leading axis, divided by 255 and binarised at 0.5. For a list of tensors, their concatenation along
        dim 0.
    """
    if isinstance(mask, torch.Tensor):
        return mask
    elif isinstance(mask, PIL.Image.Image):
        mask = [mask]

    if isinstance(mask[0], PIL.Image.Image):
        w, h = mask[0].size
        # NOTE(review): `_preprocess_image` rounds to a multiple of 8 while masks are rounded to a multiple of 32,
        # so image and mask sizes can differ for inputs that are not a multiple of 32 -- confirm this is intended.
        w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
        # NOTE(review): the stacked array has no channel axis (one 2-D array per mask) -- confirm that the
        # scheduler broadcasts it against the image tensor as expected for batches larger than one.
        mask = [np.array(m.convert("L").resize((w, h), resample=PIL_INTERPOLATION["nearest"]))[None, :] for m in mask]
        mask = np.concatenate(mask, axis=0)
        mask = mask.astype(np.float32) / 255.0
        # binarise: everything below 0.5 becomes 0, everything else becomes 1
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)
    elif isinstance(mask[0], torch.Tensor):
        mask = torch.cat(mask, dim=0)
    return mask


class RePaintPipeline(DiffusionPipeline):
    r"""
    Pipeline that inpaints an image by alternating denoising steps of `unet` with re-noising ("undo") steps of the
    scheduler, following the jump schedule produced by `scheduler.set_timesteps`.

    Parameters:
        unet ([`UNet2DModel`]):
            Model called as `unet(image, t).sample` to predict the noise residual.
        scheduler ([`RePaintScheduler`]):
            Scheduler providing `set_timesteps`, `step` and `undo_step` for the sampling loop.
    """

    unet: UNet2DModel
    scheduler: RePaintScheduler

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.Tensor, PIL.Image.Image],
        mask_image: Union[torch.Tensor, PIL.Image.Image],
        num_inference_steps: int = 250,
        eta: float = 0.0,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                The original image to inpaint on.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                The mask_image where 0.0 values define which part of the original image to inpaint (change).
            num_inference_steps (`int`, *optional*, defaults to 250):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`, *optional*, defaults to 0.0):
                The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 - 0.0 is
                DDIM and 1.0 is DDPM scheduler respectively.
            jump_length (`int`, *optional*, defaults to 10):
                The number of steps taken forward in time before going backward in time for a single jump ("j" in
                RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            jump_n_sample (`int`, *optional*, defaults to 10):
                The number of times we will make forward time jump for a given chosen time sample. Take a look at
                Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.

        Raises:
            ValueError: If `generator` is a list whose length differs from the batch size of `image`.
        """

        # bring image and mask onto the pipeline device in the dtype of the unet
        original_image = _preprocess_image(image)
        original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
        mask_image = _preprocess_mask(mask_image)
        mask_image = mask_image.to(device=self.device, dtype=self.unet.dtype)

        batch_size = original_image.shape[0]

        # sample gaussian noise to begin the loop
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        image_shape = original_image.shape
        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
        self.scheduler.eta = eta

        # start above the first timestep so that the first iteration always takes the denoising branch
        t_last = self.scheduler.timesteps[0] + 1
        # after the initial noise has been drawn, only the first generator of a list is passed to the scheduler
        generator = generator[0] if isinstance(generator, list) else generator
        for t in self.progress_bar(self.scheduler.timesteps):
            if t < t_last:
                # predict the noise residual
                model_output = self.unet(image, t).sample
                # compute previous image: x_t -> x_t-1
                image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample

            else:
                # compute the reverse: x_t-1 -> x_t
                image = self.scheduler.undo_step(image, t_last, generator)
            t_last = t

        # map the sample with x / 2 + 0.5, clamp to [0, 1] and move dim 1 to the last position
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
================================================ FILE: 6DoF/diffusers/pipelines/score_sde_ve/__init__.py ================================================ from .pipeline_score_sde_ve import ScoreSdeVePipeline ================================================ FILE: 6DoF/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Tuple, Union import torch from ...models import UNet2DModel from ...schedulers import ScoreSdeVeScheduler from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class ScoreSdeVePipeline(DiffusionPipeline): r""" Parameters: This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]): The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image. 
""" unet: UNet2DModel scheduler: ScoreSdeVeScheduler def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, num_inference_steps: int = 2000, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) model = self.unet sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps) for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step for _ in range(self.scheduler.config.correct_steps): model_output = self.unet(sample, sigma_t).sample sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample # prediction step model_output = model(sample, sigma_t).sample output = self.scheduler.step_pred(model_output, t, sample, generator=generator) sample, sample_mean = output.prev_sample, output.prev_sample_mean sample = sample_mean.clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": sample = self.numpy_to_pil(sample) if not return_dict: return (sample,) return ImagePipelineOutput(images=sample) ================================================ FILE: 6DoF/diffusers/pipelines/semantic_stable_diffusion/__init__.py ================================================ from dataclasses import dataclass from enum import Enum from typing import List, Optional, Union import numpy as np import PIL from PIL import Image from ...utils import BaseOutput, is_torch_available, is_transformers_available @dataclass class SemanticStableDiffusionPipelineOutput(BaseOutput): """ Output class for Stable Diffusion pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, or `None` if safety checking could not be performed. """ images: Union[List[PIL.Image.Image], np.ndarray] nsfw_content_detected: Optional[List[bool]] if is_transformers_available() and is_torch_available(): from .pipeline_semantic_stable_diffusion import SemanticStableDiffusionPipeline ================================================ FILE: 6DoF/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py ================================================ import inspect import warnings from itertools import repeat from typing import Callable, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import SemanticStableDiffusionPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import SemanticStableDiffusionPipeline >>> pipe = SemanticStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> out = pipe( ... prompt="a photo of the face of a woman", ... num_images_per_prompt=1, ... guidance_scale=7, ... editing_prompt=[ ... "smiling, smile", # Concepts to apply ... "glasses, wearing glasses", ... "curls, wavy hair, curly hair", ... "beard, full beard, mustache", ... ], ... reverse_editing_direction=[ ... False, ... False, ... False, ... False, ... ], # Direction of guidance i.e. increase all concepts ... 
edit_warmup_steps=[10, 10, 10, 10], # Warmup period for each concept ... edit_guidance_scale=[4, 5, 5, 5.4], # Guidance scale for each concept ... edit_threshold=[ ... 0.99, ... 0.975, ... 0.925, ... 0.96, ... ], # Threshold for each concept. Threshold equals the percentile of the latent space that will be discarded. I.e. threshold=0.99 uses 1% of the latent dimensions ... edit_momentum_scale=0.3, # Momentum scale that will be added to the latent guidance ... edit_mom_beta=0.6, # Momentum beta ... edit_weights=[1, 1, 1, 1, 1], # Weights of the individual concepts against each other ... ) >>> image = out.images[0] ``` """ class SemanticStableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation with latent editing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) This model builds on the implementation of ['StableDiffusionPipeline'] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
safety_checker ([`Q16SafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, editing_prompt: Optional[Union[str, List[str]]] = None, editing_prompt_embeddings: Optional[torch.Tensor] = None, reverse_editing_direction: Optional[Union[bool, List[bool]]] = False, edit_guidance_scale: Optional[Union[float, List[float]]] = 5, edit_warmup_steps: Optional[Union[int, List[int]]] = 10, edit_cooldown_steps: Optional[Union[int, List[int]]] = None, edit_threshold: Optional[Union[float, List[float]]] = 0.9, edit_momentum_scale: Optional[float] = 0.1, edit_mom_beta: Optional[float] = 0.4, edit_weights: Optional[List[float]] = None, sem_guidance: Optional[List[torch.Tensor]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. 
callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. editing_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to use for Semantic guidance. Semantic guidance is disabled by setting `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`. editing_prompt_embeddings (`torch.Tensor>`, *optional*): Pre-computed embeddings to use for semantic guidance. Guidance direction of embedding should be specified via `reverse_editing_direction`. reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`): Whether the corresponding prompt in `editing_prompt` should be increased or decreased. edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5): Guidance scale for semantic guidance. If provided as list values should correspond to `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 6 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10): Number of diffusion steps (for each prompt) for which semantic guidance will not be applied. Momentum will still be calculated for those steps and applied once all warmup periods are over. `edit_warmup_steps` is defined as `delta` (δ) of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`): Number of diffusion steps (for each prompt) after which semantic guidance will no longer be applied. edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9): Threshold of semantic guidance. 
edit_momentum_scale (`float`, *optional*, defaults to 0.1): Scale of the momentum to be added to the semantic guidance at each diffusion step. If set to 0.0 momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. Momentum will only be added to latent guidance once all warmup periods are finished. `edit_momentum_scale` is defined as `s_m` of equation 7 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_mom_beta (`float`, *optional*, defaults to 0.4): Defines how semantic guidance momentum builds up. `edit_mom_beta` indicates how much of the previous momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `edit_warmup_steps`. `edit_mom_beta` is defined as `beta_m` (β) of equation 8 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). edit_weights (`List[float]`, *optional*, defaults to `None`): Indicates how much each individual concept should influence the overall guidance. If no weights are provided all concepts are applied equally. `edit_mom_beta` is defined as `g_i` of equation 9 of [SEGA Paper](https://arxiv.org/pdf/2301.12247.pdf). sem_guidance (`List[torch.Tensor]`, *optional*): List of pre-generated guidance vectors to be applied at generation. Length of the list has to correspond to `num_inference_steps`. Returns: [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.semantic_stable_diffusion.SemanticStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) if editing_prompt: enable_edit_guidance = True if isinstance(editing_prompt, str): editing_prompt = [editing_prompt] enabled_editing_prompts = len(editing_prompt) elif editing_prompt_embeddings is not None: enable_edit_guidance = True enabled_editing_prompts = editing_prompt_embeddings.shape[0] else: enabled_editing_prompts = 0 enable_edit_guidance = False # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) text_input_ids = text_inputs.input_ids if text_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) if enable_edit_guidance: # get safety text embeddings if editing_prompt_embeddings is None: edit_concepts_input = self.tokenizer( [x for item in editing_prompt for x in repeat(item, batch_size)], padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt", ) edit_concepts_input_ids = 
edit_concepts_input.input_ids if edit_concepts_input_ids.shape[-1] > self.tokenizer.model_max_length: removed_text = self.tokenizer.batch_decode( edit_concepts_input_ids[:, self.tokenizer.model_max_length :] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] else: edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed_edit, seq_len_edit, _ = edit_concepts.shape edit_concepts = edit_concepts.repeat(1, num_images_per_prompt, 1) edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if enable_edit_guidance: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts]) else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=self.device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, text_embeddings.dtype, self.device, generator, latents, ) # 6. Prepare extra step kwargs. 
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # Initialize edit_momentum to None edit_momentum = None self.uncond_estimates = None self.text_estimates = None self.edit_estimates = None self.sem_guidance = None for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * (2 + enabled_editing_prompts)) if do_classifier_free_guidance else latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_out = noise_pred.chunk(2 + enabled_editing_prompts) # [b,4, 64, 64] noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] noise_pred_edit_concepts = noise_pred_out[2:] # default text guidance noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond) # noise_guidance = (noise_pred_text - noise_pred_edit_concepts[0]) if self.uncond_estimates is None: self.uncond_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_uncond.shape)) self.uncond_estimates[i] = noise_pred_uncond.detach().cpu() if self.text_estimates is None: self.text_estimates = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) self.text_estimates[i] = noise_pred_text.detach().cpu() if self.edit_estimates is None and enable_edit_guidance: self.edit_estimates = torch.zeros( (num_inference_steps + 1, len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape) ) if self.sem_guidance is None: self.sem_guidance = torch.zeros((num_inference_steps + 1, *noise_pred_text.shape)) if edit_momentum is None: edit_momentum = torch.zeros_like(noise_guidance) if enable_edit_guidance: concept_weights = torch.zeros( (len(noise_pred_edit_concepts), noise_guidance.shape[0]), device=self.device, dtype=noise_guidance.dtype, ) noise_guidance_edit = 
torch.zeros( (len(noise_pred_edit_concepts), *noise_guidance.shape), device=self.device, dtype=noise_guidance.dtype, ) # noise_guidance_edit = torch.zeros_like(noise_guidance) warmup_inds = [] for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts): self.edit_estimates[i, c] = noise_pred_edit_concept if isinstance(edit_guidance_scale, list): edit_guidance_scale_c = edit_guidance_scale[c] else: edit_guidance_scale_c = edit_guidance_scale if isinstance(edit_threshold, list): edit_threshold_c = edit_threshold[c] else: edit_threshold_c = edit_threshold if isinstance(reverse_editing_direction, list): reverse_editing_direction_c = reverse_editing_direction[c] else: reverse_editing_direction_c = reverse_editing_direction if edit_weights: edit_weight_c = edit_weights[c] else: edit_weight_c = 1.0 if isinstance(edit_warmup_steps, list): edit_warmup_steps_c = edit_warmup_steps[c] else: edit_warmup_steps_c = edit_warmup_steps if isinstance(edit_cooldown_steps, list): edit_cooldown_steps_c = edit_cooldown_steps[c] elif edit_cooldown_steps is None: edit_cooldown_steps_c = i + 1 else: edit_cooldown_steps_c = edit_cooldown_steps if i >= edit_warmup_steps_c: warmup_inds.append(c) if i >= edit_cooldown_steps_c: noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept) continue noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3)) tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3)) tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts) if reverse_editing_direction_c: noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1 concept_weights[c, :] = tmp_weights noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c # torch.quantile function expects float32 if noise_guidance_edit_tmp.dtype == torch.float32: tmp = torch.quantile( 
torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2), edit_threshold_c, dim=2, keepdim=False, ) else: tmp = torch.quantile( torch.abs(noise_guidance_edit_tmp).flatten(start_dim=2).to(torch.float32), edit_threshold_c, dim=2, keepdim=False, ).to(noise_guidance_edit_tmp.dtype) noise_guidance_edit_tmp = torch.where( torch.abs(noise_guidance_edit_tmp) >= tmp[:, :, None, None], noise_guidance_edit_tmp, torch.zeros_like(noise_guidance_edit_tmp), ) noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp warmup_inds = torch.tensor(warmup_inds).to(self.device) if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: concept_weights = concept_weights.to("cpu") # Offload to cpu noise_guidance_edit = noise_guidance_edit.to("cpu") concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) concept_weights_tmp = torch.where( concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp ) concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) noise_guidance_edit_tmp = torch.index_select( noise_guidance_edit.to(self.device), 0, warmup_inds ) noise_guidance_edit_tmp = torch.einsum( "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp ) noise_guidance_edit_tmp = noise_guidance_edit_tmp noise_guidance = noise_guidance + noise_guidance_edit_tmp self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu() del noise_guidance_edit_tmp del concept_weights_tmp concept_weights = concept_weights.to(self.device) noise_guidance_edit = noise_guidance_edit.to(self.device) concept_weights = torch.where( concept_weights < 0, torch.zeros_like(concept_weights), concept_weights ) concept_weights = torch.nan_to_num(concept_weights) noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) noise_guidance_edit = noise_guidance_edit + 
edit_momentum_scale * edit_momentum edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit if warmup_inds.shape[0] == len(noise_pred_edit_concepts): noise_guidance = noise_guidance + noise_guidance_edit self.sem_guidance[i] = noise_guidance_edit.detach().cpu() if sem_guidance is not None: edit_guidance = sem_guidance[i].to(self.device) noise_guidance = noise_guidance + edit_guidance noise_pred = noise_pred_uncond + noise_guidance # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/shap_e/__init__.py ================================================ from ...utils import ( OptionalDependencyNotAvailable, is_torch_available, is_transformers_available, is_transformers_version, ) try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline else: from .camera import create_pan_cameras from .pipeline_shap_e import ShapEPipeline from 
@dataclass
class DifferentiableProjectiveCamera:
    """
    Implements a batch, differentiable, standard pinhole camera.

    The camera frame of every batch element is given by three axis vectors (`x`, `y`, `z`) and an `origin`; `z` is
    the base ray direction to which `x`/`y` offsets are added (see `get_camera_rays`). `shape` is split by
    `camera_rays` into `(batch_size, *inner_shape)`.
    """

    origin: torch.Tensor  # [batch_size x 3]
    x: torch.Tensor  # [batch_size x 3]
    y: torch.Tensor  # [batch_size x 3]
    z: torch.Tensor  # [batch_size x 3]
    width: int
    height: int
    x_fov: float
    y_fov: float
    shape: Tuple[int, ...]

    def __post_init__(self):
        # All four tensors must be 2-D, share the same leading dimension and have exactly 3 components.
        assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0]
        assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3
        assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2

    def resolution(self):
        """Return `[width, height]` as a float32 tensor."""
        return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32))

    def fov(self):
        """Return `[x_fov, y_fov]` as a float32 tensor."""
        return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32))

    def get_image_coords(self) -> torch.Tensor:
        """
        :return: coords of shape (width * height, 2)
        """
        pixel_indices = torch.arange(self.height * self.width)
        # Row-major pixel order: column index first, then row index.
        coords = torch.stack(
            [
                pixel_indices % self.width,
                torch.div(pixel_indices, self.width, rounding_mode="trunc"),
            ],
            dim=1,
        )
        return coords

    @property
    def camera_rays(self):
        """
        Rays for every pixel of every camera.

        :return: tensor of shape (batch_size, inner_batch_size * height * width, 2, 3) where index 0 of the
            second-to-last axis is the ray origin and index 1 the normalized direction.
        """
        batch_size, *inner_shape = self.shape
        inner_batch_size = int(np.prod(inner_shape))

        coords = self.get_image_coords()
        coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
        rays = self.get_camera_rays(coords)

        rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3)

        return rays

    def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Convert pixel coordinates to rays.

        :param coords: tensor of shape (batch_size, *shape, 2) whose leading dimension must equal the number of
            camera origins.
        :return: tensor of shape (batch_size, *shape, 2, 3) holding (origin, unit direction) pairs.
        """
        batch_size, *shape, n_coords = coords.shape
        assert n_coords == 2
        assert batch_size == self.origin.shape[0]

        flat = coords.view(batch_size, -1, 2)

        res = self.resolution()
        fov = self.fov()

        # Map pixel indices to [-1, 1] and scale by tan(fov / 2) to get offsets along the x / y axes.
        fracs = (flat.float() / (res - 1)) * 2 - 1
        fracs = fracs * torch.tan(fov / 2)

        fracs = fracs.view(batch_size, -1, 2)
        directions = (
            self.z.view(batch_size, 1, 3)
            + self.x.view(batch_size, 1, 3) * fracs[:, :, :1]
            + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:]
        )
        directions = directions / directions.norm(dim=-1, keepdim=True)
        rays = torch.stack(
            [
                torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]),
                directions,
            ],
            dim=2,
        )
        return rays.view(batch_size, *shape, 2, 3)

    def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera":
        """
        Creates a new camera for the resized view assuming the aspect ratio does not change.
        """
        assert width * self.height == height * self.width, "The aspect ratio should not change."
        return DifferentiableProjectiveCamera(
            origin=self.origin,
            x=self.x,
            y=self.y,
            z=self.z,
            width=width,
            height=height,
            x_fov=self.x_fov,
            y_fov=self.y_fov,
            # `shape` is a required dataclass field; omitting it made every call raise a TypeError.
            shape=self.shape,
        )


def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera:
    """
    Build 20 square cameras (`size` x `size` pixels) panning a full turn around the origin.

    Each camera sits at distance 4 from the origin along `-z`, i.e. it looks back through the origin. The returned
    camera has `shape=(1, 20)`.
    """
    origins = []
    xs = []
    ys = []
    zs = []
    for theta in np.linspace(0, 2 * np.pi, num=20):
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        origin = -z * 4
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return DifferentiableProjectiveCamera(
        origin=torch.from_numpy(np.stack(origins, axis=0)).float(),
        x=torch.from_numpy(np.stack(xs, axis=0)).float(),
        y=torch.from_numpy(np.stack(ys, axis=0)).float(),
        z=torch.from_numpy(np.stack(zs, axis=0)).float(),
        width=size,
        height=size,
        x_fov=0.7,
        y_fov=0.7,
        shape=(1, len(xs)),
    )
@dataclass
class ShapEPipelineOutput(BaseOutput):
    """
    Output class for ShapEPipeline.

    Args:
        images (`List[List[PIL.Image.Image]]` or array of rendered frames):
            One sequence of rendered frames per generated 3D asset. When the pipeline is called with
            `output_type="latent"` this field holds the raw latents tensor instead.
    """

    images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]]


class ShapEPipeline(DiffusionPipeline):
    """
    Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`HeunDiscreteScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
        renderer ([`ShapERenderer`]):
            Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects
            with the NeRF rendering method
    """

    def __init__(
        self,
        prior: PriorTransformer,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        scheduler: HeunDiscreteScheduler,
        renderer: ShapERenderer,
    ):
        super().__init__()

        self.register_modules(
            prior=prior,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            renderer=renderer,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
        when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        models = [self.text_encoder, self.prior]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `prior`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # This pipeline registers no `safety_checker` module, so a plain attribute access would raise.
        # Look it up defensively so a subclass that does register one keeps working.
        safety_checker = getattr(self, "safety_checker", None)
        if safety_checker is not None:
            _, hook = cpu_offload_with_hook(safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"):
            return self.device
        for module in self.text_encoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
    ):
        """
        Encode `prompt` into normalized, rescaled CLIP text embeddings.

        The embeddings are repeated `num_images_per_prompt` times per prompt. With classifier-free guidance an
        all-zero block of the same shape is prepended as the unconditional embedding, so the first dimension doubles.
        """
        # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file
        self.tokenizer.pad_token_id = 0
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_encoder_output = self.text_encoder(text_input_ids.to(device))
        prompt_embeds = text_encoder_output.text_embeds

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        # in Shap-E it normalize the prompt_embeds and then later rescale it
        prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True)

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # Rescale the features to have unit variance
        prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds

        return prompt_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 25,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        guidance_scale: float = 4.0,
        frame_size: int = 64,
        output_type: Optional[str] = "pil",  # pil, np, latent
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            frame_size (`int`, *optional*, default to 64):
                the width and height of each image frame of the generated 3d output
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated frames. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"latent"` (return the denoised latents without rendering).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple. Not consulted when
                `output_type="latent"`, which always returns a [`ShapEPipelineOutput`].

        Examples:

        Returns:
            [`ShapEPipelineOutput`] or `tuple`
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Reject unsupported output types before spending time on denoising and rendering.
        if output_type not in ["np", "pil", "latent"]:
            raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)

        # prior

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_embeddings = self.prior.config.num_embeddings
        embedding_dim = self.prior.config.embedding_dim

        latents = self.prepare_latents(
            (batch_size, num_embeddings * embedding_dim),
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            self.scheduler,
        )

        # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
        latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

        for t in self.progress_bar(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.prior(
                scaled_model_input,
                timestep=t,
                proj_embedding=prompt_embeds,
            ).predicted_image_embedding

            # remove the variance
            noise_pred, _ = noise_pred.split(
                scaled_model_input.shape[2], dim=2
            )  # batch_size, num_embeddings, embedding_dim

            # The flag is a bool, so it must be tested for truth: an `is not None` check is always true and
            # would split (and mix) the batch even when the latents were not duplicated above.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            latents = self.scheduler.step(
                noise_pred,
                timestep=t,
                sample=latents,
            ).prev_sample

        if output_type == "latent":
            return ShapEPipelineOutput(images=latents)

        images = []
        for latent in latents:
            image = self.renderer.decode(
                latent[None, :],
                device,
                size=frame_size,
                ray_batch_size=4096,
                n_coarse_samples=64,
                n_fine_samples=128,
            )
            images.append(image)

        images = torch.stack(images)

        images = images.cpu().numpy()

        if output_type == "pil":
            images = [self.numpy_to_pil(image) for image in images]

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (images,)

        return ShapEPipelineOutput(images=images)
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPVisionModel from ...models import PriorTransformer from ...pipelines import DiffusionPipeline from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, is_accelerate_available, logging, randn_tensor, replace_example_docstring, ) from .renderer import ShapERenderer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> from PIL import Image >>> import torch >>> from diffusers import DiffusionPipeline >>> from diffusers.utils import export_to_gif, load_image >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") >>> repo = "openai/shap-e-img2img" >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) >>> pipe = pipe.to(device) >>> guidance_scale = 3.0 >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png" >>> image = load_image(image_url).convert("RGB") >>> images = pipe( ... image, ... guidance_scale=guidance_scale, ... num_inference_steps=64, ... frame_size=256, ... ).images >>> gif_path = export_to_gif(images[0], "corgi_3d.gif") ``` """ @dataclass class ShapEPipelineOutput(BaseOutput): """ Output class for ShapEPipeline. Args: images (`torch.FloatTensor`) a list of images for 3D rendering """ images: Union[PIL.Image.Image, np.ndarray] class ShapEImg2ImgPipeline(DiffusionPipeline): """ Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Args: prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler ([`HeunDiscreteScheduler`]): A scheduler to be used in combination with `prior` to generate image embedding. renderer ([`ShapERenderer`]): Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects with the NeRF rendering method """ def __init__( self, prior: PriorTransformer, image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, scheduler: HeunDiscreteScheduler, renderer: ShapERenderer, ): super().__init__() self.register_modules( prior=prior, image_encoder=image_encoder, image_processor=image_processor, scheduler=scheduler, renderer=renderer, ) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") models = [self.image_encoder, self.prior] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if self.device != torch.device("meta") or not hasattr(self.image_encoder, "_hf_hook"): return self.device for module in self.image_encoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_image( self, image, device, num_images_per_prompt, do_classifier_free_guidance, ): if isinstance(image, List) and isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) if not isinstance(image, torch.Tensor): image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0) image = image.to(dtype=self.image_encoder.dtype, device=device) image_embeds = self.image_encoder(image)["last_hidden_state"] image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256 image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: negative_image_embeds = torch.zeros_like(image_embeds) # For classifier free guidance, we need to do two forward passes. 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: Union[PIL.Image.Image, List[PIL.Image.Image]],
    num_images_per_prompt: int = 1,
    num_inference_steps: int = 25,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    guidance_scale: float = 4.0,
    frame_size: int = 64,
    output_type: Optional[str] = "pil",  # pil, np, latent
    return_dict: bool = True,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        image (`PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]`):
            The image or images used to condition the generation.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of outputs to generate per input image.
        num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps. More denoising steps usually lead to a higher quality result at the
            expense of slower inference.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for
            generation. If not provided, a latents tensor will be generated by sampling using the supplied random
            `generator`.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`.
        frame_size (`int`, *optional*, defaults to 64):
            The width and height of each image frame of the generated 3d output.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format. Choose between `"pil"`, `"np"` (`np.array`) or `"latent"` (return the raw latents
            without rendering).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple. Ignored when
            `output_type="latent"`, which always returns a [`ShapEPipelineOutput`].

    Examples:

    Returns:
        [`ShapEPipelineOutput`] or `tuple`

    Raises:
        ValueError: If `image` has an unsupported type or `output_type` is not one of the supported values.
    """
    # Fail fast: reject an unsupported output type before spending time on diffusion and rendering.
    if output_type not in ("pil", "np", "latent"):
        raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}")

    if isinstance(image, PIL.Image.Image):
        batch_size = 1
    elif isinstance(image, torch.Tensor):
        batch_size = image.shape[0]
    elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)):
        batch_size = len(image)
    else:
        raise ValueError(
            f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}"
        )

    device = self._execution_device

    batch_size = batch_size * num_images_per_prompt

    do_classifier_free_guidance = guidance_scale > 1.0
    image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)

    # prior
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    num_embeddings = self.prior.config.num_embeddings
    embedding_dim = self.prior.config.embedding_dim

    latents = self.prepare_latents(
        (batch_size, num_embeddings * embedding_dim),
        image_embeds.dtype,
        device,
        generator,
        latents,
        self.scheduler,
    )

    # Latents are prepared flat and reshaped to (batch_size, num_embeddings, embedding_dim) for the prior.
    latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

    for t in self.progress_bar(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        noise_pred = self.prior(
            scaled_model_input,
            timestep=t,
            proj_embedding=image_embeds,
        ).predicted_image_embedding

        # remove the variance
        noise_pred, _ = noise_pred.split(
            scaled_model_input.shape[2], dim=2
        )  # batch_size, num_embeddings, embedding_dim

        # Only split into unconditional/conditional halves when the batch was actually doubled above.
        # (The previous `is not None` test was always true for a bool.)
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

        latents = self.scheduler.step(
            noise_pred,
            timestep=t,
            sample=latents,
        ).prev_sample

    if output_type == "latent":
        return ShapEPipelineOutput(images=latents)

    # Render every latent separately; the renderer is called with a leading batch dimension of 1.
    images = [
        self.renderer.decode(
            latent[None, :],
            device,
            size=frame_size,
            ray_batch_size=4096,
            n_coarse_samples=64,
            n_fine_samples=128,
        )
        for latent in latents
    ]
    images = torch.stack(images)

    images = images.cpu().numpy()

    if output_type == "pil":
        images = [self.numpy_to_pil(image) for image in images]

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (images,)

    return ShapEPipelineOutput(images=images)
def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
    r"""
    Draw bin indices, with replacement, from a discrete probability distribution.

    Bin `i` is treated as carrying probability mass `pmf[i]`. Sampling inverts the cumulative distribution: uniform
    random numbers are located inside the running sum of `pmf`.

    Args:
        pmf: [batch_size, *shape, support_size, 1] where (pmf.sum(dim=-2) == 1).all()
        n_samples: number of indices to draw for every distribution

    Return:
        integer indices of shape [batch_size, *shape, n_samples, 1], each within [0, support_size - 1]
    """
    leading_dims = pmf.shape[:-2]
    support_size, last_dim = pmf.shape[-2:]
    assert last_dim == 1

    # One row per distribution so a single searchsorted call handles them all.
    flat_cdf = pmf.view(-1, support_size).cumsum(dim=1)
    uniform = torch.rand(flat_cdf.shape[0], n_samples, device=flat_cdf.device)
    bins = torch.searchsorted(flat_cdf, uniform)

    # Clamp guards against an index of `support_size` when rounding leaves the last cdf entry below the draw.
    return bins.view(*leading_dims, n_samples, 1).clamp(0, support_size - 1)
def integrate_samples(volume_range, ts, density, channels):
    r"""
    Integrate per-sample model outputs along each ray.

    Args:
        volume_range: object whose `partition(ts)` returns `(lower, upper, delta)`; only `delta` is used here
        ts: sample positions along the ray, forwarded to `volume_range.partition`
        density: torch.Tensor [batch_size, *shape, n_samples, 1]
        channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]

    returns:
        channels: the per-sample channels summed with their weights over the sample axis
        weights: torch.Tensor [batch_size, *shape, n_samples, 1], `alpha * accumulated transmittance` per sample
        transmittance: `exp(-total accumulated density)` over the whole range
    """
    # Density scaled by the width of the interval each sample represents.
    _lower, _upper, interval_width = volume_range.partition(ts)
    optical_depth = density * interval_width

    accumulated = optical_depth.cumsum(dim=-2)
    transmittance = (-accumulated[..., -1, :]).exp()

    opacity = 1.0 - (-optical_depth).exp()

    # Transmittance *before* each sample: shift the running sum by one and start from zero.
    leading_zero = torch.zeros_like(accumulated[..., :1, :])
    survival = torch.cat([leading_zero, -accumulated[..., :-1, :]], dim=-2).exp()

    # Probability of light hitting and reflecting off something at depth [..., i, :].
    weights = opacity * survival

    integrated = (channels * weights).sum(dim=-2)
    return integrated, weights, transmittance
@dataclass
class VolumeRange:
    """
    Per-ray integration range `[t0, t1]` together with a flag telling whether the ray hit the volume.

    All three tensors must share the same shape; this is asserted on construction.
    """

    t0: torch.Tensor
    t1: torch.Tensor
    intersected: torch.Tensor

    def __post_init__(self):
        reference_shape = self.t0.shape
        assert reference_shape == self.t1.shape == self.intersected.shape

    def partition(self, ts):
        """
        Split `[t0, t1]` into one interval per sample, with boundaries at the midpoints between neighbours.

        Args:
            ts: [batch_size, *shape, n_samples, 1]

        Return:
            lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size,
            *shape, n_samples, 1]

        where
            ts \\in [lower, upper] deltas = upper - lower
        """
        # Interior boundaries sit halfway between consecutive samples.
        midpoints = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5

        range_start = self.t0.unsqueeze(-2)
        range_end = self.t1.unsqueeze(-2)

        lower = torch.cat([range_start, midpoints], dim=-2)
        upper = torch.cat([midpoints, range_end], dim=-2)
        delta = upper - lower

        assert lower.shape == upper.shape == delta.shape == ts.shape
        return lower, upper, delta
""" super().__init__() self.min_dist = min_dist self.min_t_range = min_t_range self.bbox_min = torch.tensor(bbox_min) self.bbox_max = torch.tensor(bbox_max) self.bbox = torch.stack([self.bbox_min, self.bbox_max]) assert self.bbox.shape == (2, 3) assert min_dist >= 0.0 assert min_t_range > 0.0 def intersect( self, origin: torch.Tensor, direction: torch.Tensor, t0_lower: Optional[torch.Tensor] = None, epsilon=1e-6, ): """ Args: origin: [batch_size, *shape, 3] direction: [batch_size, *shape, 3] t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. params: Optional meta parameters in case Volume is parametric epsilon: to stabilize calculations Return: A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to be on the boundary of the volume. """ batch_size, *shape, _ = origin.shape ones = [1] * len(shape) bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) def _safe_divide(a, b, epsilon=1e-6): return a / torch.where(b < 0, b - epsilon, b + epsilon) ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) # Cases to think about: # # 1. t1 <= t0: the ray does not pass through the AABB. # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. # 3. t0 <= 0 <= t1: the ray starts from inside the BB # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. # # 1 and 4 are clearly handled from t0 < t1 below. # Making t0 at least min_dist (>= 0) takes care of 2 and 3. 
class StratifiedRaySampler(nn.Module):
    """
    Instead of fixed intervals, a sample is drawn uniformly at random from each interval.
    """

    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: linear samples ts linearly in depth. harmonic ensures
            closer points are sampled more densely.
        """
        # nn.Module must be initialised before attributes are set so the module is fully usable
        # (e.g. `.training`, `.to()`, `.parameters()`).
        super().__init__()
        self.depth_mode = depth_mode
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        Args:
            t0: start time has shape [batch_size, *shape, 1]
            t1: finish time has shape [batch_size, *shape, 1]
            n_samples: number of ts to sample
            epsilon: lower clamp applied to t0/t1 in the "geometric" and "harmonic" modes
        Return:
            sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        ones = [1] * (len(t0.shape) - 1)
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        if self.depth_mode == "linear":
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # The original NeRF recommends this interpolation scheme for
            # spherical scenes, but there could be some weird edge cases when
            # the observer crosses from the inner to outer volume.
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        # Each sample is jittered inside the interval bounded by the midpoints to its neighbours.
        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)

        # The jitter has historically been drawn right after `torch.manual_seed(0)`, which made renders
        # reproducible but reset the process-wide RNG on every call. A private generator with the same
        # seed keeps the jitter deterministic without touching global random state.
        jitter_generator = torch.Generator(device=ts.device)
        jitter_generator.manual_seed(0)
        t_rand = torch.rand(ts.shape, generator=jitter_generator, dtype=ts.dtype, device=ts.device)

        ts = lower + (upper - lower) * t_rand
        return ts.unsqueeze(-1)
@dataclass
class MLPNeRFModelOutput(BaseOutput):
    """
    Output container returned by `MLPNeRSTFModel.forward`.

    `density`, `signed_distance` and `channels` are the activated slices of the MLP output for the requested NeRF
    level; `ts` is the tensor that was passed to `forward`, returned unchanged.
    """

    # relu-activated density slice of the MLP output
    density: torch.Tensor
    # tanh-activated "sdf" slice of the MLP output
    signed_distance: torch.Tensor
    # sigmoid-activated colour-channel slice of the MLP output
    channels: torch.Tensor
    # sample positions the model was queried with
    ts: torch.Tensor
class ChannelsProj(nn.Module):
    """
    Project each of `vectors` latent rows to `channels` values, using a separate weight slice per row.

    The weights live in a single `nn.Linear(d_latent, vectors * channels)`, viewed as `(vectors, channels, d_latent)`.
    The projection is layer-normalised over the channel axis before the bias is added.
    """

    def __init__(
        self,
        *,
        vectors: int,
        channels: int,
        d_latent: int,
    ):
        super().__init__()
        self.vectors = vectors
        self.channels = channels
        self.d_latent = d_latent
        self.proj = nn.Linear(d_latent, vectors * channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input used as the `bvd` operand of the einsum, i.e. [batch, vectors, d_latent]

        Return:
            tensor of shape [batch, vectors, channels]
        """
        per_vector_weight = self.proj.weight.view(self.vectors, self.channels, self.d_latent)
        per_vector_bias = self.proj.bias.view(1, self.vectors, self.channels)

        projected = torch.einsum("bvd,vcd->bvc", x, per_vector_weight)
        # The bias is applied after normalisation, not before.
        return self.norm(projected) + per_vector_bias
class ShapERenderer(ModelMixin, ConfigMixin):
    """
    Renderer that turns a latent into images: the latent is projected to MLP weights, the weights are copied into
    the NeRSTF MLP, and camera rays are rendered with coarse stratified sampling followed by fine importance
    sampling.
    """

    @register_to_config
    def __init__(
        self,
        *,
        param_names: Tuple[str] = (
            "nerstf.mlp.0.weight",
            "nerstf.mlp.1.weight",
            "nerstf.mlp.2.weight",
            "nerstf.mlp.3.weight",
        ),
        param_shapes: Tuple[Tuple[int]] = (
            (256, 93),
            (256, 256),
            (256, 256),
            (256, 256),
        ),
        d_latent: int = 1024,
        d_hidden: int = 256,
        n_output: int = 12,
        n_hidden_layers: int = 6,
        act_fn: str = "swish",
        insert_direction_at: int = 4,
        background: Tuple[float] = (
            255.0,
            255.0,
            255.0,
        ),
    ):
        super().__init__()

        self.params_proj = ShapEParamsProjModel(
            param_names=param_names,
            param_shapes=param_shapes,
            d_latent=d_latent,
        )
        self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at)
        self.void = VoidNeRFModel(background=background, channel_scale=255.0)
        self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0])

    @torch.no_grad()
    def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False):
        """
        Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written
        below with some abuse of notations)

            C(r) := sum(
                transmittance(t[i]) * integrate(
                    lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]],
                ) for i in range(len(parts))
            ) + transmittance(t[-1]) * void_model(t[-1]).channels

        where

        1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing
           through the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely)
        2) density and channels are obtained by evaluating the appropriate part.model at time t.
        3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects (parts[i].volume \\
           union(part.volume for part in parts[:i])) at the surface of the shell (if bounded). If the ray does not
           intersect, the integral over this segment is evaluated as 0 and transmittance(t[i + 1]) :=
           transmittance(t[i]).
        4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model
           (i.e. we consider this space to be empty).

        args:
            rays: [batch_size x ... x 2 x 3] origin and direction.
            sampler: object whose `sample(t0, t1, n_samples)` returns the ts to integrate over.
            n_samples: number of ts to sample.
            prev_model_out: model output from the previous (coarse) rendering step; when given, its ts are merged
                with the new samples and the "fine" outputs of the MLP are used.
            render_with_direction: whether the ray direction is passed to the MLP.

        :return: A tuple of
            - `channels`
            - An importance sampler for additional fine-grained rendering
            - raw model output
        """
        origin, direction = rays[..., 0, :], rays[..., 1, :]

        # Integrate over [t[i], t[i + 1]]

        # 1. Intersect the rays with the current volume and sample ts to integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=None)
        ts = sampler.sample(vrange.t0, vrange.t1, n_samples)
        ts = ts.to(rays.dtype)

        if prev_model_out is not None:
            # Append the previous ts now before fprop because previous
            # rendering used a different model and we can't reuse the output.
            ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values

        batch_size = vrange.t0.shape[0]
        ts_shape = ts.shape[1:-1]

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        directions = directions.to(self.mlp.dtype)
        positions = positions.to(self.mlp.dtype)

        optional_directions = directions if render_with_direction else None

        model_out = self.mlp(
            position=positions,
            direction=optional_directions,
            ts=ts,
            nerf_level="coarse" if prev_model_out is None else "fine",
        )

        # 3. Integrate the model results
        channels, weights, transmittance = integrate_samples(
            vrange, model_out.ts, model_out.density, model_out.channels
        )

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance))
        channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels))

        # 5. Integration to infinity (e.g. [t[-1], math.inf]) is evaluated by the void_model (i.e. we consider
        # this space to be empty).
        channels = channels + transmittance * self.void(origin)

        weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights)

        return channels, weighted_sampler, model_out

    @torch.no_grad()
    def decode(
        self,
        latents,
        device,
        size: int = 64,
        ray_batch_size: int = 4096,
        n_coarse_samples=64,
        n_fine_samples=128,
    ):
        """
        Render images for a latent.

        Args:
            latents: latent passed to the parameter projection; the projected weights are squeezed on dim 0
                before being copied into the MLP.
            device: device the camera rays are moved to.
            size: value passed to `create_pan_cameras`.
            ray_batch_size: maximum number of rays rendered per batch. The last batch may be smaller when the
                number of rays is not a multiple of this value.
            n_coarse_samples: number of stratified samples per ray in the coarse pass.
            n_fine_samples: number of importance samples per ray in the fine pass.

        Return:
            rendered channels viewed as `(*camera.shape, camera.height, camera.width, -1)` with dim 0 squeezed.
        """
        # project the parameters from the generated latents
        projected_params = self.params_proj(latents)

        # update the mlp layers of the renderer
        for name, param in self.mlp.state_dict().items():
            projected_key = f"nerstf.{name}"
            if projected_key in projected_params:
                param.copy_(projected_params[projected_key].squeeze(0))

        # create cameras object
        camera = create_pan_cameras(size)
        rays = camera.camera_rays
        rays = rays.to(device)

        coarse_sampler = StratifiedRaySampler()

        images = []

        # Step through the rays in chunks. Iterating by start offset (instead of `n_rays // ray_batch_size`
        # full batches) also renders the trailing partial batch, so every ray is covered even when the ray
        # count is smaller than, or not a multiple of, `ray_batch_size`.
        n_rays = rays.shape[1]
        for start in range(0, n_rays, ray_batch_size):
            rays_batch = rays[:, start : start + ray_batch_size]

            # render rays with coarse, stratified samples.
            _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples)
            # Then, render with additional importance-weighted ray samples.
            channels, _, _ = self.render_rays(
                rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out
            )

            images.append(channels)

        images = torch.cat(images, dim=1)
        images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0)

        return images
class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Encoder for continuous inputs: a bias-free linear input projection plus a frozen position embedding, followed
    by dropout, a stack of `T5Block` layers, a `T5LayerNorm` and a final dropout.
    """

    @register_to_config
    def __init__(
        self,
        input_dims: int,
        targets_context_length: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        super().__init__()

        self.input_proj = nn.Linear(input_dims, d_model, bias=False)

        # The position table is excluded from gradient updates.
        self.position_encoding = nn.Embedding(targets_context_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.dropout_pre = nn.Dropout(p=dropout_rate)

        block_config = T5Config(
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            feed_forward_proj=feed_forward_proj,
            dropout_rate=dropout_rate,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )
        self.encoders = nn.ModuleList([T5Block(block_config) for _ in range(num_layers)])

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_inputs, encoder_inputs_mask):
        """
        Encode `encoder_inputs` and return `(encoded, encoder_inputs_mask)`; the mask is returned unchanged.
        """
        hidden = self.input_proj(encoder_inputs)

        # terminal relative positional encodings
        sequence_length = encoder_inputs.shape[1]
        positions = torch.arange(sequence_length, device=encoder_inputs.device)

        valid_lengths = encoder_inputs_mask.sum(-1)
        roll_shifts = tuple(valid_lengths.tolist())
        positions = torch.roll(positions.unsqueeze(0), roll_shifts, dims=0)
        hidden += self.position_encoding(positions)

        hidden = self.dropout_pre(hidden)

        # inverted the attention mask
        extended_attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, encoder_inputs.size())

        for block in self.encoders:
            hidden = block(hidden, extended_attention_mask)[0]
        hidden = self.layer_norm(hidden)

        return self.dropout_post(hidden), encoder_inputs_mask
# Length every token sequence is padded to by `Tokenizer.encode`.
INPUT_FEATURE_LENGTH = 2048

SAMPLE_RATE = 16000
HOP_SIZE = 320
FRAME_RATE = int(SAMPLE_RATE // HOP_SIZE)

DEFAULT_STEPS_PER_SECOND = 100
DEFAULT_MAX_SHIFT_SECONDS = 10
DEFAULT_NUM_VELOCITY_BINS = 1

# Instrument class name -> program number used as the representative of that class.
SLAKH_CLASS_PROGRAMS = {
    "Acoustic Piano": 0,
    "Electric Piano": 4,
    "Chromatic Percussion": 8,
    "Organ": 16,
    "Acoustic Guitar": 24,
    "Clean Electric Guitar": 26,
    "Distorted Electric Guitar": 29,
    "Acoustic Bass": 32,
    "Electric Bass": 33,
    "Violin": 40,
    "Viola": 41,
    "Cello": 42,
    "Contrabass": 43,
    "Orchestral Harp": 46,
    "Timpani": 47,
    "String Ensemble": 48,
    "Synth Strings": 50,
    "Choir and Voice": 52,
    "Orchestral Hit": 55,
    "Trumpet": 56,
    "Trombone": 57,
    "Tuba": 58,
    "French Horn": 60,
    "Brass Section": 61,
    "Soprano/Alto Sax": 64,
    "Tenor Sax": 66,
    "Baritone Sax": 67,
    "Oboe": 68,
    "English Horn": 69,
    "Bassoon": 70,
    "Clarinet": 71,
    "Pipe": 73,
    "Synth Lead": 80,
    "Synth Pad": 88,
}


@dataclasses.dataclass
class NoteRepresentationConfig:
    """Configuration for note representations."""

    onsets_only: bool
    include_ties: bool


@dataclasses.dataclass
class NoteEventData:
    """One note onset/offset; only `pitch` is mandatory."""

    pitch: int
    velocity: Optional[int] = None
    program: Optional[int] = None
    is_drum: Optional[bool] = None
    instrument: Optional[int] = None


@dataclasses.dataclass
class NoteEncodingState:
    """Encoding state for note transcription, keeping track of active pitches."""

    # velocity bin for active pitches and programs, keyed by (pitch, program)
    active_pitches: MutableMapping[Tuple[int, int], int] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class EventRange:
    """Inclusive value range `[min_value, max_value]` of one event type."""

    type: str
    min_value: int
    max_value: int


@dataclasses.dataclass
class Event:
    """A typed event with an integer value."""

    type: str
    value: int


class Tokenizer:
    """Shift codec indices past the special tokens, append EOS and pad."""

    def __init__(self, regular_ids: int):
        """Store the number of regular (non-special) token ids."""
        # The special tokens: 0=PAD, 1=EOS, and 2=UNK
        self._num_special_tokens = 3
        self._num_regular_tokens = regular_ids

    def encode(self, token_ids):
        """Return `token_ids` offset by the special tokens, EOS-terminated and zero-padded.

        Raises:
            ValueError: If a token id is outside `[0, regular_ids)`.
        """
        encoded = []
        for token_id in token_ids:
            if not 0 <= token_id < self._num_regular_tokens:
                raise ValueError(
                    f"token_id {token_id} does not fall within valid range of [0, {self._num_regular_tokens})"
                )
            encoded.append(token_id + self._num_special_tokens)

        # Add EOS token
        encoded.append(1)

        # Pad up to INPUT_FEATURE_LENGTH. A longer sequence is returned unpadded
        # (the multiplier is negative, which yields an empty list).
        encoded = encoded + [0] * (INPUT_FEATURE_LENGTH - len(encoded))

        return encoded


class Codec:
    """Encode and decode events.

    Useful for declaring what certain ranges of a vocabulary should be used for. This is intended to be used from
    Python before encoding or after decoding with GenericTokenVocabulary. This class is more lightweight and does not
    include things like EOS or UNK token handling.

    To ensure that 'shift' events are always the first block of the vocab and start at 0, that event type is required
    and specified separately.
    """

    def __init__(self, max_shift_steps: int, steps_per_second: float, event_ranges: List[EventRange]):
        """Define Codec.

        Args:
          max_shift_steps: Maximum number of shift steps that can be encoded.
          steps_per_second: Shift steps will be interpreted as having a duration of
              1 / steps_per_second.
          event_ranges: Other supported event types and their ranges.
        """
        self.steps_per_second = steps_per_second
        self._shift_range = EventRange(type="shift", min_value=0, max_value=max_shift_steps)
        self._event_ranges = [self._shift_range] + event_ranges
        # Ensure all event types have unique names.
        assert len(self._event_ranges) == len({er.type for er in self._event_ranges})

    @property
    def num_classes(self) -> int:
        """Total number of indices over all event ranges."""
        return sum(er.max_value - er.min_value + 1 for er in self._event_ranges)

    # The next couple methods are simplified special case methods just for shift
    # events that are intended to be used from within autograph functions.

    def is_shift_event_index(self, index: int) -> bool:
        """Return whether `index` falls inside the shift range."""
        return (self._shift_range.min_value <= index) and (index <= self._shift_range.max_value)

    @property
    def max_shift_steps(self) -> int:
        """Largest shift value a single shift token can carry."""
        return self._shift_range.max_value

    def encode_event(self, event: Event) -> int:
        """Encode an event to an index.

        Raises:
            ValueError: If the event type is unknown or its value is out of range.
        """
        offset = 0
        for er in self._event_ranges:
            if event.type == er.type:
                if not er.min_value <= event.value <= er.max_value:
                    raise ValueError(
                        f"Event value {event.value} is not within valid range "
                        f"[{er.min_value}, {er.max_value}] for type {event.type}"
                    )
                return offset + event.value - er.min_value
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"Unknown event type: {event.type}")

    def event_type_range(self, event_type: str) -> Tuple[int, int]:
        """Return [min_id, max_id] for an event type."""
        offset = 0
        for er in self._event_ranges:
            if event_type == er.type:
                return offset, offset + (er.max_value - er.min_value)
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"Unknown event type: {event_type}")

    def decode_event_index(self, index: int) -> Event:
        """Decode an event index to an Event."""
        offset = 0
        for er in self._event_ranges:
            if offset <= index <= offset + er.max_value - er.min_value:
                return Event(type=er.type, value=er.min_value + index - offset)
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"Unknown event index: {index}")


@dataclasses.dataclass
class ProgramGranularity:
    """Pair of mapping functions applied to token sequences and to program numbers."""

    # both tokens_map_fn and program_map_fn should be idempotent
    tokens_map_fn: Callable[[Sequence[int], Codec], Sequence[int]]
    program_map_fn: Callable[[int], int]


def drop_programs(tokens, codec: Codec):
    """Drops program change events from a token sequence."""
    min_program_id, max_program_id = codec.event_type_range("program")
    return tokens[(tokens < min_program_id) | (tokens > max_program_id)]


def programs_to_midi_classes(tokens, codec):
    """Modifies program events to be the first program in the MIDI class."""
    min_program_id, max_program_id = codec.event_type_range("program")
    is_program = (tokens >= min_program_id) & (tokens <= max_program_id)
    return np.where(is_program, min_program_id + 8 * ((tokens - min_program_id) // 8), tokens)


PROGRAM_GRANULARITIES = {
    # "flat" granularity; drop program change tokens and set NoteSequence
    # programs to zero
    "flat": ProgramGranularity(tokens_map_fn=drop_programs, program_map_fn=lambda program: 0),
    # map each program to the first program in its MIDI class
    "midi_class": ProgramGranularity(
        tokens_map_fn=programs_to_midi_classes, program_map_fn=lambda program: 8 * (program // 8)
    ),
    # leave programs as is
    "full": ProgramGranularity(tokens_map_fn=lambda tokens, codec: tokens, program_map_fn=lambda program: program),
}


def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1):
    """Split `signal` into (possibly overlapping) frames along `axis`; equivalent of tf.signal.frame.

    Args:
        signal: Tensor to frame.
        frame_length: Number of samples per frame.
        frame_step: Hop between the starts of consecutive frames.
        pad_end: If True, pad the end of `axis` with `pad_value` so that
            `ceil(signal_length / frame_step)` frames are produced and no extra
            frame is added when the signal already fits exactly.
        pad_value: Constant used for padding.
        axis: Dimension to frame (negative values count from the end).

    Returns:
        The result of `signal.unfold(axis, frame_length, frame_step)`.
    """
    signal_length = signal.shape[axis]
    if pad_end:
        # Ceil division gives the frame count; pad only as much as the last frame needs.
        num_frames = -(-signal_length // frame_step)
        pad_size = max(0, frame_length + frame_step * (num_frames - 1) - signal_length)

        if pad_size != 0:
            # F.pad expects (left, right) pairs ordered from the LAST dimension backwards.
            dim = axis % signal.ndim
            pad_spec = [0] * (2 * signal.ndim)
            pad_spec[2 * (signal.ndim - 1 - dim) + 1] = pad_size
            signal = F.pad(signal, pad_spec, "constant", pad_value)
    frames = signal.unfold(axis, frame_length, frame_step)
    return frames


def program_to_slakh_program(program):
    """Return the largest `SLAKH_CLASS_PROGRAMS` value that is `<= program`.

    Returns None when `program` is below every listed value.
    """
    # this is done very hackily, probably should use a custom mapping
    for slakh_program in sorted(SLAKH_CLASS_PROGRAMS.values(), reverse=True):
        if program >= slakh_program:
            return slakh_program
    return None


def audio_to_frames(
    samples,
    hop_size: int,
    frame_rate: int,
) -> Tuple[torch.Tensor, np.ndarray]:
    """Convert audio samples to non-overlapping frames and frame times.

    Returns:
        `(frames, times)` where `frames` is a tensor with a leading dimension of 1
        and `times[i]` is `i / frame_rate`.
    """
    frame_size = hop_size
    # Pads to the next multiple of frame_size; an already aligned input still
    # receives one full extra frame (existing behaviour, kept for compatibility).
    samples = np.pad(samples, [0, frame_size - len(samples) % frame_size], mode="constant")

    # Split audio into frames. The samples are already padded above, so pad_end stays False.
    frames = frame(
        torch.Tensor(samples).unsqueeze(0),
        frame_length=frame_size,
        frame_step=frame_size,
        pad_end=False,
    )

    num_frames = len(samples) // frame_size

    times = np.arange(num_frames) / frame_rate
    return frames, times


def note_sequence_to_onsets_and_offsets_and_programs(
    ns: "note_seq.NoteSequence",
) -> Tuple[Sequence[float], Sequence[NoteEventData]]:
    """Extract onset & offset times and pitches & programs from a NoteSequence.

    The onset & offset times will not necessarily be in sorted order.

    Args:
      ns: NoteSequence from which to extract onsets and offsets.

    Returns:
      times: A list of note onset and offset times. values: A list of NoteEventData objects where velocity is zero for
      note offsets.
    """
    # Sort by program and pitch and put offsets before onsets as a tiebreaker for
    # subsequent stable sort.
    notes = sorted(ns.notes, key=lambda note: (note.is_drum, note.program, note.pitch))
    # Drum notes contribute an onset only; every other note has an offset and an onset.
    times = [note.end_time for note in notes if not note.is_drum] + [note.start_time for note in notes]
    values = [
        NoteEventData(pitch=note.pitch, velocity=0, program=note.program, is_drum=False)
        for note in notes
        if not note.is_drum
    ] + [
        NoteEventData(pitch=note.pitch, velocity=note.velocity, program=note.program, is_drum=note.is_drum)
        for note in notes
    ]
    return times, values


def num_velocity_bins_from_codec(codec: Codec):
    """Get number of velocity bins from event codec."""
    lo, hi = codec.event_type_range("velocity")
    return hi - lo


# segment an array into segments of length n
def segment(a, n):
    """Return consecutive slices of `a` of length `n` (the last one may be shorter)."""
    return [a[i : i + n] for i in range(0, len(a), n)]


def velocity_to_bin(velocity, num_velocity_bins):
    """Map a velocity to a bin index; velocity 0 always maps to bin 0."""
    if velocity == 0:
        return 0
    else:
        return math.ceil(num_velocity_bins * velocity / note_seq.MAX_MIDI_VELOCITY)


def note_event_data_to_events(
    state: Optional[NoteEncodingState],
    value: NoteEventData,
    codec: Codec,
) -> Sequence[Event]:
    """Convert note event data to a sequence of events.

    When `state` is given, the velocity bin of non-drum notes is recorded in
    `state.active_pitches` under `(pitch, program)` (program 0 if none is set).
    """
    if value.velocity is None:
        # onsets only, no program or velocity
        return [Event("pitch", value.pitch)]
    else:
        num_velocity_bins = num_velocity_bins_from_codec(codec)
        velocity_bin = velocity_to_bin(value.velocity, num_velocity_bins)
        if value.program is None:
            # onsets + offsets + velocities only, no programs
            if state is not None:
                state.active_pitches[(value.pitch, 0)] = velocity_bin
            return [Event("velocity", velocity_bin), Event("pitch", value.pitch)]
        else:
            if value.is_drum:
                # drum events use a separate vocabulary
                return [Event("velocity", velocity_bin), Event("drum", value.pitch)]
            else:
                # program + velocity + pitch
                if state is not None:
                    state.active_pitches[(value.pitch, value.program)] = velocity_bin
                return [
                    Event("program", value.program),
                    Event("velocity", velocity_bin),
                    Event("pitch", value.pitch),
                ]


def note_encoding_state_to_events(state: NoteEncodingState) -> Sequence[Event]:
    """Output program and pitch events for active notes plus a final tie event."""
    events = []
    # Order by (program, pitch); entries with a zero velocity bin are inactive.
    for pitch, program in sorted(state.active_pitches.keys(), key=lambda k: k[::-1]):
        if state.active_pitches[(pitch, program)]:
            events += [Event("program", program), Event("pitch", pitch)]
    events.append(Event("tie", 0))
    return events


def encode_and_index_events(
    state, event_times, event_values, codec, frame_times, encode_event_fn, encoding_state_to_events_fn=None
):
    """Encode a sequence of timed events and index to audio frame times.

    Encodes time shifts as repeated single step shifts for later run length encoding.

    Optionally, also encodes a sequence of "state events", keeping track of the current encoding state at each audio
    frame. This can be used e.g. to prepend events representing the current state to a targets segment.

    Args:
      state: Initial event encoding state.
      event_times: Sequence of event times.
      event_values: Sequence of event values.
      encode_event_fn: Function that transforms event value into a sequence of one
          or more Event objects.
      codec: A Codec object that maps Event objects to indices.
      frame_times: Time for every audio frame.
      encoding_state_to_events_fn: Function that transforms encoding state into a
          sequence of one or more Event objects.

    Returns:
      A list with one dict per segment of `TARGET_FEATURE_LENGTH` audio frames. Every dict holds:
        inputs: All encoded events and shifts (the same full array in every dict).
        event_start_indices: Start event index for every audio frame of the segment. One event can correspond to
            multiple audio indices due to sampling rate differences, so the same event can appear at the end of one
            segment and the beginning of another.
        event_end_indices: End event index for every audio frame of the segment, built so that
            event_end_indices[i] = event_start_indices[i + 1].
        state_events: All encoded "state" events (the same full array in every dict).
        state_event_indices: State event index for every audio frame of the segment.
    """
    indices = np.argsort(event_times, kind="stable")
    event_steps = [round(event_times[i] * codec.steps_per_second) for i in indices]
    event_values = [event_values[i] for i in indices]

    events = []
    state_events = []
    event_start_indices = []
    state_event_indices = []

    cur_step = 0
    cur_event_idx = 0
    cur_state_event_idx = 0

    # The single-step shift token is constant, so it is encoded once.
    shift_token = codec.encode_event(Event(type="shift", value=1))

    def fill_event_start_indices_to_cur_step():
        # Assign the current indices to every frame that starts before cur_step.
        while (
            len(event_start_indices) < len(frame_times)
            and frame_times[len(event_start_indices)] < cur_step / codec.steps_per_second
        ):
            event_start_indices.append(cur_event_idx)
            state_event_indices.append(cur_state_event_idx)

    for event_step, event_value in zip(event_steps, event_values):
        while event_step > cur_step:
            events.append(shift_token)
            cur_step += 1
            fill_event_start_indices_to_cur_step()
            cur_event_idx = len(events)
            cur_state_event_idx = len(state_events)
        if encoding_state_to_events_fn:
            # Dump state to state events *before* processing the next event, because
            # we want to capture the state prior to the occurrence of the event.
            for e in encoding_state_to_events_fn(state):
                state_events.append(codec.encode_event(e))

        for e in encode_event_fn(state, event_value, codec):
            events.append(codec.encode_event(e))

    # After the last event, continue filling out the event_start_indices array.
    # The inequality is not strict because if our current step lines up exactly
    # with (the start of) an audio frame, we need to add an additional shift event
    # to "cover" that frame.
    while cur_step / codec.steps_per_second <= frame_times[-1]:
        events.append(shift_token)
        cur_step += 1
        fill_event_start_indices_to_cur_step()
        cur_event_idx = len(events)

    # Now fill in event_end_indices. We need this extra array to make sure that
    # when we slice events, each slice ends exactly where the subsequent slice
    # begins.
    event_end_indices = event_start_indices[1:] + [len(events)]

    events = np.array(events).astype(np.int32)
    state_events = np.array(state_events).astype(np.int32)
    event_start_indices = segment(np.array(event_start_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
    event_end_indices = segment(np.array(event_end_indices).astype(np.int32), TARGET_FEATURE_LENGTH)
    state_event_indices = segment(np.array(state_event_indices).astype(np.int32), TARGET_FEATURE_LENGTH)

    outputs = []
    for start_indices, end_indices, event_indices in zip(event_start_indices, event_end_indices, state_event_indices):
        outputs.append(
            {
                "inputs": events,
                "event_start_indices": start_indices,
                "event_end_indices": end_indices,
                "state_events": state_events,
                "state_event_indices": event_indices,
            }
        )

    return outputs


def extract_sequence_with_indices(features, state_events_end_token=None, feature_key="inputs"):
    """Extract target sequence corresponding to audio token segment.

    Works on a shallow copy of `features`, so the caller's dict keeps its entries.
    """
    features = features.copy()
    start_idx = features["event_start_indices"][0]
    end_idx = features["event_end_indices"][-1]

    features[feature_key] = features[feature_key][start_idx:end_idx]

    if state_events_end_token is not None:
        # Extract the state events corresponding to the audio start token, and
        # prepend them to the targets array.
        state_event_start_idx = features["state_event_indices"][0]
        state_event_end_idx = state_event_start_idx + 1
        # Advance until the slice ends with the end token (inclusive).
        while features["state_events"][state_event_end_idx - 1] != state_events_end_token:
            state_event_end_idx += 1
        features[feature_key] = np.concatenate(
            [
                features["state_events"][state_event_start_idx:state_event_end_idx],
                features[feature_key],
            ],
            axis=0,
        )

    return features


def map_midi_programs(
    feature, codec: Codec, granularity_type: str = "full", feature_key: str = "inputs"
) -> Mapping[str, Any]:
    """Apply MIDI program map to token sequences (updates `feature` in place and returns it)."""
    granularity = PROGRAM_GRANULARITIES[granularity_type]

    feature[feature_key] = granularity.tokens_map_fn(feature[feature_key], codec)
    return feature


def run_length_encode_shifts_fn(
    features,
    codec: Codec,
    feature_key: str = "inputs",
    state_change_event_types: Sequence[str] = (),
) -> Mapping[str, Any]:
    """Run-length encode the single-step shifts of `features[feature_key]` for a given codec.

    Args:
      features: Dict of features to process; `features[feature_key]` is replaced in place.
      codec: The Codec to use for shift events.
      feature_key: The feature key for which to run-length encode shifts.
      state_change_event_types: A list of event types that represent state
        changes; tokens corresponding to these event types will be interpreted as state changes and redundant ones
        will be removed.

    Returns:
      The `features` dict, with `features[feature_key]` replaced by an int32 array.
    """
    state_change_event_ranges = [codec.event_type_range(event_type) for event_type in state_change_event_types]

    def run_length_encode_shifts(features: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Combine leading/interior shifts, trim trailing shifts.

        Args:
          features: Dict of features to process.

        Returns:
          A dict of features.
        """
        events = features[feature_key]

        shift_steps = 0
        total_shift_steps = 0
        # Collected in a list and converted once; avoids rebuilding an array per token.
        output = []

        current_state = np.zeros(len(state_change_event_ranges), dtype=np.int32)

        for event in events:
            if codec.is_shift_event_index(event):
                shift_steps += 1
                total_shift_steps += 1

            else:
                # If this event is a state change and has the same value as the current
                # state, we can skip it entirely.
                is_redundant = False
                for i, (min_index, max_index) in enumerate(state_change_event_ranges):
                    if (min_index <= event) and (event <= max_index):
                        if current_state[i] == event:
                            is_redundant = True
                        current_state[i] = event
                if is_redundant:
                    continue

                # Once we've reached a non-shift event, RLE all previous shift events
                # before outputting the non-shift event. The emitted amount is the
                # total shift since the start of the sequence, split into chunks of
                # at most codec.max_shift_steps.
                if shift_steps > 0:
                    shift_steps = total_shift_steps
                    while shift_steps > 0:
                        output_steps = min(codec.max_shift_steps, shift_steps)
                        output.append(output_steps)
                        shift_steps -= output_steps
                output.append(event)

        features[feature_key] = np.array(output, dtype=np.int32)
        return features

    return run_length_encode_shifts(features)


def note_representation_processor_chain(features, codec: Codec, note_representation_config: NoteRepresentationConfig):
    """Extract one segment, map its programs and run-length encode its shifts."""
    tie_token = codec.encode_event(Event("tie", 0))
    # State events are only prepended when ties are part of the representation.
    state_events_end_token = tie_token if note_representation_config.include_ties else None

    features = extract_sequence_with_indices(
        features, state_events_end_token=state_events_end_token, feature_key="inputs"
    )

    features = map_midi_programs(features, codec)

    features = run_length_encode_shifts_fn(features, codec, state_change_event_types=["velocity", "program"])

    return features


class MidiProcessor:
    """Convert a MIDI file (or its bytes) into padded token sequences, one per segment."""

    def __init__(self):
        """Create the codec, tokenizer and note representation config."""
        self.codec = Codec(
            max_shift_steps=DEFAULT_MAX_SHIFT_SECONDS * DEFAULT_STEPS_PER_SECOND,
            steps_per_second=DEFAULT_STEPS_PER_SECOND,
            event_ranges=[
                EventRange("pitch", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH),
                EventRange("velocity", 0, DEFAULT_NUM_VELOCITY_BINS),
                EventRange("tie", 0, 0),
                EventRange("program", note_seq.MIN_MIDI_PROGRAM, note_seq.MAX_MIDI_PROGRAM),
                EventRange("drum", note_seq.MIN_MIDI_PITCH, note_seq.MAX_MIDI_PITCH),
            ],
        )
        self.tokenizer = Tokenizer(self.codec.num_classes)
        self.note_representation_config = NoteRepresentationConfig(onsets_only=False, include_ties=True)

    def __call__(self, midi: Union[bytes, os.PathLike, str]):
        """Tokenize `midi`.

        Args:
            midi: Raw MIDI bytes, or a path that is opened in binary mode and read.

        Returns:
            A list of token lists, one per segment, each produced by `Tokenizer.encode`.
        """
        if not isinstance(midi, bytes):
            with open(midi, "rb") as f:
                midi = f.read()

        ns = note_seq.midi_to_note_sequence(midi)
        ns_sus = note_seq.apply_sustain_control_changes(ns)

        # Collapse every non-drum program onto its class representative.
        for note in ns_sus.notes:
            if not note.is_drum:
                note.program = program_to_slakh_program(note.program)

        # Silent placeholder audio; only its length is used to derive frame times.
        samples = np.zeros(int(ns_sus.total_time * SAMPLE_RATE))

        _, frame_times = audio_to_frames(samples, HOP_SIZE, FRAME_RATE)
        times, values = note_sequence_to_onsets_and_offsets_and_programs(ns_sus)

        events = encode_and_index_events(
            state=NoteEncodingState(),
            event_times=times,
            event_values=values,
            frame_times=frame_times,
            codec=self.codec,
            encode_event_fn=note_event_data_to_events,
            encoding_state_to_events_fn=note_encoding_state_to_events,
        )

        events = [
            note_representation_processor_chain(event, self.codec, self.note_representation_config) for event in events
        ]
        input_tokens = [self.tokenizer.encode(event["inputs"]) for event in events]

        return input_tokens
class SpectrogramNotesEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """Encoder for note tokens built from a stack of `T5Block` layers.

    Tokens are embedded, a frozen learned position embedding is added, and the
    result is passed through `num_layers` T5 blocks followed by a `T5LayerNorm`
    and dropout.
    """

    @register_to_config
    def __init__(
        self,
        max_length: int,
        vocab_size: int,
        d_model: int,
        dropout_rate: float,
        num_layers: int,
        num_heads: int,
        d_kv: int,
        d_ff: int,
        feed_forward_proj: str,
        is_decoder: bool = False,
    ):
        """Build the embeddings, T5 stack and output norm.

        Args:
            max_length: Number of rows in the position-embedding table.
            vocab_size: Number of rows in the token-embedding table.
            d_model: Hidden width used by every sub-module.
            dropout_rate: Dropout probability (input dropout, T5 blocks, output dropout).
            num_layers: Number of `T5Block` layers.
            num_heads, d_kv, d_ff, feed_forward_proj, is_decoder: Forwarded to `T5Config`.
        """
        super().__init__()

        self.token_embedder = nn.Embedding(vocab_size, d_model)

        # Position table; its weights are excluded from gradient updates.
        self.position_encoding = nn.Embedding(max_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.dropout_pre = nn.Dropout(p=dropout_rate)

        block_config = T5Config(
            vocab_size=vocab_size,
            d_model=d_model,
            num_heads=num_heads,
            d_kv=d_kv,
            d_ff=d_ff,
            dropout_rate=dropout_rate,
            feed_forward_proj=feed_forward_proj,
            is_decoder=is_decoder,
            is_encoder_decoder=False,
        )
        self.encoders = nn.ModuleList([T5Block(block_config) for _ in range(num_layers)])

        self.layer_norm = T5LayerNorm(d_model)
        self.dropout_post = nn.Dropout(p=dropout_rate)

    def forward(self, encoder_input_tokens, encoder_inputs_mask):
        """Encode `encoder_input_tokens` and return `(encoded, encoder_inputs_mask)`.

        Args:
            encoder_input_tokens: Integer token tensor; dimension 1 is used as the
                sequence length (assumes (batch, seq) — TODO confirm).
            encoder_inputs_mask: Mask used as the attention mask.

        Returns:
            A tuple of the encoded tensor (after layer norm and dropout) and the
            unchanged `encoder_inputs_mask`.
        """
        hidden = self.token_embedder(encoder_input_tokens)

        # Absolute positions 0..seq-1, shared by every row through broadcasting.
        num_positions = encoder_input_tokens.shape[1]
        positions = torch.arange(num_positions, device=encoder_input_tokens.device)
        hidden += self.position_encoding(positions)

        hidden = self.dropout_pre(hidden)

        # Expand the mask to the form expected by the attention layers.
        attention_mask = self.get_extended_attention_mask(encoder_inputs_mask, encoder_input_tokens.size())

        for block in self.encoders:
            hidden = block(hidden, attention_mask)[0]

        hidden = self.layer_norm(hidden)
        return self.dropout_post(hidden), encoder_inputs_mask
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Number of frames generated per segment; also used as the context length.
TARGET_FEATURE_LENGTH = 256


class SpectrogramDiffusionPipeline(DiffusionPipeline):
    """Generate audio segment by segment from note tokens.

    Each segment is denoised from Gaussian noise, conditioned on the encoded note
    tokens and on the previously generated segment. The concatenated result is
    optionally passed through the `melgan` component.
    """

    _optional_components = ["melgan"]

    def __init__(
        self,
        notes_encoder: SpectrogramNotesEncoder,
        continuous_encoder: SpectrogramContEncoder,
        decoder: T5FilmDecoder,
        scheduler: DDPMScheduler,
        melgan: OnnxRuntimeModel if is_onnx_available() else Any,
    ) -> None:
        """Register the sub-modules and the feature scaling constants."""
        super().__init__()

        # From MELGAN
        self.min_value = math.log(1e-5)  # Matches MelGAN training.
        self.max_value = 4.0  # Largest value for most examples
        self.n_dims = 128

        self.register_modules(
            notes_encoder=notes_encoder,
            continuous_encoder=continuous_encoder,
            decoder=decoder,
            scheduler=scheduler,
            melgan=melgan,
        )

    def scale_features(self, features, output_range=(-1.0, 1.0), clip=False):
        """Linearly scale features to network outputs range."""
        min_out, max_out = output_range
        if clip:
            features = torch.clip(features, self.min_value, self.max_value)
        # Scale to [0, 1].
        zero_one = (features - self.min_value) / (self.max_value - self.min_value)
        # Scale to [min_out, max_out].
        return zero_one * (max_out - min_out) + min_out

    def scale_to_features(self, outputs, input_range=(-1.0, 1.0), clip=False):
        """Invert by linearly scaling network outputs to features range."""
        min_out, max_out = input_range
        outputs = torch.clip(outputs, min_out, max_out) if clip else outputs
        # Scale to [0, 1].
        zero_one = (outputs - min_out) / (max_out - min_out)
        # Scale to [self.min_value, self.max_value].
        return zero_one * (self.max_value - self.min_value) + self.min_value

    def encode(self, input_tokens, continuous_inputs, continuous_mask):
        """Run both encoders and return `[(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]`."""
        # Token id 0 is treated as padding and masked out.
        tokens_mask = input_tokens > 0
        tokens_encoded, tokens_mask = self.notes_encoder(
            encoder_input_tokens=input_tokens, encoder_inputs_mask=tokens_mask
        )

        continuous_encoded, continuous_mask = self.continuous_encoder(
            encoder_inputs=continuous_inputs, encoder_inputs_mask=continuous_mask
        )

        return [(tokens_encoded, tokens_mask), (continuous_encoded, continuous_mask)]

    def decode(self, encodings_and_masks, input_tokens, noise_time):
        """Run the decoder for one noise level and return its output.

        `noise_time` may be a Python number, a 0-d tensor or a tensor; it is
        broadcast to one value per row of `input_tokens`.
        """
        timesteps = noise_time
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=input_tokens.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(input_tokens.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(input_tokens.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        logits = self.decoder(
            encodings_and_masks=encodings_and_masks, decoder_input_tokens=input_tokens, decoder_noise_time=timesteps
        )
        return logits

    @torch.no_grad()
    def __call__(
        self,
        input_tokens: List[List[int]],
        generator: Optional[torch.Generator] = None,
        num_inference_steps: int = 100,
        return_dict: bool = True,
        output_type: str = "numpy",
        callback: Optional[Callable[[int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ) -> Union[AudioPipelineOutput, Tuple]:
        """Generate one segment per entry of `input_tokens` and concatenate them.

        Args:
            input_tokens: One token list per segment.
            generator: Passed to the noise sampler and to `scheduler.step`.
            num_inference_steps: Passed to `scheduler.set_timesteps` for every segment.
            return_dict: If False, return a 1-tuple instead of `AudioPipelineOutput`.
            output_type: `"numpy"` runs the `melgan` component on the generated
                features; any other value returns the features themselves.
            callback: Called as `callback(segment_index, full_pred_mel)` after every
                `callback_steps`-th segment.
            callback_steps: Positive integer segment interval between callback calls.

        Raises:
            ValueError: If `callback_steps` is not a positive integer, or if
                `output_type == "numpy"` but ONNX or the `melgan` component is missing.
        """
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # Fail before the expensive diffusion loop if the requested output cannot be produced.
        if output_type == "numpy" and not is_onnx_available():
            raise ValueError(
                "Cannot return output in 'numpy' format if ONNX is not available. Make sure to have ONNX installed or set 'output_type' to 'mel'."
            )
        elif output_type == "numpy" and self.melgan is None:
            raise ValueError(
                "Cannot return output in 'numpy' format if melgan component is not defined. Make sure to define `self.melgan` or set 'output_type' to 'mel'."
            )

        pred_mel = np.zeros([1, TARGET_FEATURE_LENGTH, self.n_dims], dtype=np.float32)
        full_pred_mel = np.zeros([1, 0, self.n_dims], np.float32)
        ones = torch.ones((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device)

        for i, encoder_input_tokens in enumerate(input_tokens):
            if i == 0:
                encoder_continuous_inputs = torch.from_numpy(pred_mel[:1].copy()).to(
                    device=self.device, dtype=self.decoder.dtype
                )
                # The first chunk has no previous context.
                encoder_continuous_mask = torch.zeros((1, TARGET_FEATURE_LENGTH), dtype=bool, device=self.device)
            else:
                # The full song pipeline does not feed in a context feature, so the mask
                # will be all 0s after the feature converter. Because we know we're
                # feeding in a full context chunk from the previous prediction, set it
                # to all 1s.
                encoder_continuous_mask = ones

            encoder_continuous_inputs = self.scale_features(
                encoder_continuous_inputs, output_range=[-1.0, 1.0], clip=True
            )

            encodings_and_masks = self.encode(
                input_tokens=torch.IntTensor([encoder_input_tokens]).to(device=self.device),
                continuous_inputs=encoder_continuous_inputs,
                continuous_mask=encoder_continuous_mask,
            )

            # Sample encoder_continuous_inputs shaped gaussian noise to begin loop
            x = randn_tensor(
                shape=encoder_continuous_inputs.shape,
                generator=generator,
                device=self.device,
                dtype=self.decoder.dtype,
            )

            # set step values
            self.scheduler.set_timesteps(num_inference_steps)

            # Denoising diffusion loop
            for t in self.progress_bar(self.scheduler.timesteps):
                output = self.decode(
                    encodings_and_masks=encodings_and_masks,
                    input_tokens=x,
                    noise_time=t / self.scheduler.config.num_train_timesteps,  # rescale to [0, 1)
                )

                # Compute previous output: x_t -> x_t-1
                x = self.scheduler.step(output, t, x, generator=generator).prev_sample

            mel = self.scale_to_features(x, input_range=[-1.0, 1.0])
            # The generated segment becomes the context of the next one.
            encoder_continuous_inputs = mel[:1]
            pred_mel = mel.cpu().float().numpy()

            full_pred_mel = np.concatenate([full_pred_mel, pred_mel[:1]], axis=1)

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, full_pred_mel)

            # %-style placeholder is required; a bare extra argument makes logging fail to format.
            logger.info("Generated segment %d", i)

        if output_type == "numpy":
            output = self.melgan(input_features=full_pred_mel.astype(np.float32))
        else:
            output = full_pred_mel

        if not return_dict:
            return (output,)

        return AudioPipelineOutput(audios=output)


# The definition below belongs to 6DoF/diffusers/pipelines/stable_diffusion/__init__.py.
@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
# The Flax output class and the Flax pipelines are only defined/imported when both transformers and flax are
# reported as available, so importing this package never requires flax.
if is_transformers_available() and is_flax_available():
    import flax

    @flax.struct.dataclass
    class FlaxStableDiffusionPipelineOutput(BaseOutput):
        """
        Output class for the Flax Stable Diffusion pipelines.

        Args:
            images (`np.ndarray`)
                Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
            nsfw_content_detected (`List[bool]`)
                List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
                (nsfw) content.
        """

        images: np.ndarray
        nsfw_content_detected: List[bool]

    # Flax-only re-exports; kept inside the guard because they are only imported when flax is available.
    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
    from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
    from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
    from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
    from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
def shave_segments(path, n_shave_prefix_segments=1):
    """
    Drop dot-separated segments from one end of `path`.

    Args:
        path (`str`): dotted key, e.g. `"input_blocks.1.0.weight"`.
        n_shave_prefix_segments (`int`): when zero or positive, that many leading segments are removed; when
            negative, `abs(n_shave_prefix_segments)` trailing segments are removed.

    Returns:
        `str`: the remaining segments joined with `"."` (an empty string when nothing is left).
    """
    segments = path.split(".")
    if n_shave_prefix_segments < 0:
        # Negative count: keep everything except the last |n| segments.
        kept = segments[:n_shave_prefix_segments]
    else:
        # Non-negative count: skip the first n segments (n == 0 keeps the path intact).
        kept = segments[n_shave_prefix_segments:]
    return ".".join(kept)
def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    Final conversion step: copy locally renamed weights from `old_checkpoint` into `checkpoint` under their
    globally renamed keys.

    Args:
        paths (`list` of `dict`): entries with an `"old"` key (name in `old_checkpoint`) and a `"new"` key (locally
            renamed name).
        checkpoint (`dict`): destination state dict; modified in place.
        old_checkpoint (`dict`): source state dict the tensors are read from.
        attention_paths_to_split (`dict`, *optional*): maps the key of a fused attention tensor in `old_checkpoint`
            to a dict with `"query"`, `"key"` and `"value"` destination keys. Requires `config["num_head_channels"]`.
        additional_replacements (`list` of `dict`, *optional*): extra `{"old": ..., "new": ...}` substring
            replacements applied to every new key after the built-in `middle_block.*` renames.
        config (`dict`, *optional*): only read (for `"num_head_channels"`) when `attention_paths_to_split` is given.

    Returns:
        `None`; results are written into `checkpoint`.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    split_requested = attention_paths_to_split is not None

    # Break every fused attention tensor into its query / key / value parts.
    if split_requested:
        for fused_path, destinations in attention_paths_to_split.items():
            fused = old_checkpoint[fused_path]
            channels = fused.shape[0] // 3

            if len(fused.shape) == 3:
                target_shape = (-1, channels)
            else:
                target_shape = -1

            num_heads = fused.shape[0] // config["num_head_channels"] // 3

            # Regroup per head so the three projections can be separated along dim 1.
            fused = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
            query, key, value = fused.split(channels // num_heads, dim=1)

            for role, part in (("query", query), ("key", key), ("value", value)):
                checkpoint[destinations[role]] = part.reshape(target_shape)

    # Renames that apply to every key, whatever block it belongs to.
    global_renames = (
        ("middle_block.0", "mid_block.resnets.0"),
        ("middle_block.1", "mid_block.attentions.0"),
        ("middle_block.2", "mid_block.resnets.1"),
    )

    for entry in paths:
        target_name = entry["new"]

        # Entries listed in the split mapping were handled by the loop above.
        if split_requested and target_name in attention_paths_to_split:
            continue

        for before, after in global_renames:
            target_name = target_name.replace(before, after)

        if additional_replacements is not None:
            for extra in additional_replacements:
                target_name = target_name.replace(extra["old"], extra["new"])

        tensor = old_checkpoint[entry["old"]]
        rank = len(tensor.shape)

        # Attention projections stored with trailing size-1 conv dims are reduced to 2-D linear weights.
        is_attn_weight = "proj_attn.weight" in target_name or ("attentions" in target_name and "to_" in target_name)
        if is_attn_weight and rank == 3:
            tensor = tensor[:, :, 0]
        elif is_attn_weight and rank == 4:
            tensor = tensor[:, :, 0, 0]

        checkpoint[target_name] = tensor
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Build the keyword dictionary for a diffusers UNet (or ControlNet) from an LDM-style config.

    Args:
        original_config: config object supporting attribute access and `in` tests; the UNet parameters are read from
            `model.params.control_stage_config.params` when `controlnet` is true, otherwise from
            `model.params.unet_config.params` (falling back to `model.params.network_config.params`).
        image_size (`int`): pixel resolution; divided by the VAE scale factor to obtain `sample_size`.
        controlnet (`bool`): when true the result carries `conditioning_channels` instead of `out_channels` and
            `up_block_types`.

    Returns:
        `dict`: the diffusers config.

    Raises:
        NotImplementedError: if `num_classes` is present with a value other than `"sequential"`.
    """
    model_params = original_config.model.params
    if controlnet:
        unet_params = model_params.control_stage_config.params
    elif "unet_config" in model_params and model_params.unet_config is not None:
        unet_params = model_params.unet_config.params
    else:
        unet_params = model_params.network_config.params

    vae_params = model_params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
    num_blocks = len(block_out_channels)

    # Walk down the resolution pyramid; the last level does not downsample any further.
    down_block_types = []
    resolution = 1
    for index in range(num_blocks):
        attends = resolution in unet_params.attention_resolutions
        down_block_types.append("CrossAttnDownBlock2D" if attends else "DownBlock2D")
        if index != num_blocks - 1:
            resolution *= 2

    # Walk back up, starting from the resolution reached at the bottom.
    up_block_types = []
    for _ in range(num_blocks):
        attends = resolution in unet_params.attention_resolutions
        up_block_types.append("CrossAttnUpBlock2D" if attends else "UpBlock2D")
        resolution //= 2

    depth = unet_params.transformer_depth
    if depth is None:
        transformer_layers_per_block = 1
    elif isinstance(depth, int):
        transformer_layers_per_block = depth
    else:
        transformer_layers_per_block = list(depth)

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = False
    if "use_linear_in_transformer" in unet_params:
        use_linear_projection = unet_params.use_linear_in_transformer
    if use_linear_projection and head_dim is None:
        # stable diffusion 2-base-512 and 2-768
        head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
        head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

    context_dim = unet_params.context_dim
    if context_dim is not None and not isinstance(context_dim, int):
        # A sequence of context dims: only the first entry is used.
        context_dim = context_dim[0]

    class_embed_type = None
    addition_embed_type = None
    addition_time_embed_dim = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes != "sequential":
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
        if context_dim in [2048, 1280]:
            # SDXL
            addition_embed_type = "text_time"
            addition_time_embed_dim = 256
        else:
            class_embed_type = "projection"
        assert "adm_in_channels" in unet_params
        projection_class_embeddings_input_dim = unet_params.adm_in_channels

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "addition_embed_type": addition_embed_type,
        "addition_time_embed_dim": addition_time_embed_dim,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
        "transformer_layers_per_block": transformer_layers_per_block,
    }

    if controlnet:
        config["conditioning_channels"] = unet_params.hint_channels
    else:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config
def create_ldm_bert_config(original_config):
    """
    Build an `LDMBertConfig` from the text-encoder section of an LDM config.

    Args:
        original_config: config object with attribute access; the BERT parameters are read from
            `model.params.cond_stage_config.params` (`n_embed` and `n_layer`).

    Returns:
        `LDMBertConfig`: config whose feed-forward width is four times the embedding width.
    """
    # Fixed: this used to read `model.parms`, a typo for `model.params`, which is the attribute every other
    # config accessor in this module uses.
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") logger.warning( "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) for key in keys: if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] if config["class_embed_type"] is None: # No parameters to port ... 
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] else: raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") if config["addition_embed_type"] == "text_time": new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] if not controlnet: new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) 
} # Retrieves the keys for the output blocks only num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) } for i in range(1, num_input_blocks): block_id = (i - 1) // (config["layers_per_block"] + 1) layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.weight" ) new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( f"input_blocks.{i}.0.op.bias" ) paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) if len(attentions): paths = renew_attention_paths(attentions) meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) resnet_0 = middle_blocks[0] attentions = middle_blocks[1] resnet_1 = middle_blocks[2] resnet_0_paths = renew_resnet_paths(resnet_0) assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) resnet_1_paths = renew_resnet_paths(resnet_1) assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( 
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) for i in range(num_output_blocks): block_id = i // (config["layers_per_block"] + 1) layer_in_block_id = i % (config["layers_per_block"] + 1) output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] output_block_list = {} for layer in output_block_layers: layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) if layer_id in output_block_list: output_block_list[layer_id].append(layer_name) else: output_block_list[layer_id] = [layer_name] if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.weight" ] new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ f"output_blocks.{i}.{index}.conv.bias" ] # Clear attentions as they have been attributed above. 
if len(attentions) == 2: attentions = [] if len(attentions): paths = renew_attention_paths(attentions) meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) else: resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) new_checkpoint[new_path] = unet_state_dict[old_path] if controlnet: # conditioning embedding orig_index = 0 new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) orig_index += 2 diffusers_index = 0 while diffusers_index < 6: new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) diffusers_index += 1 orig_index += 2 new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.weight" ) new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( f"input_hint_block.{orig_index}.bias" ) # down blocks for i in range(num_input_blocks): new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") # mid block new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") new_checkpoint["controlnet_mid_block.bias"] = 
def convert_ldm_vae_checkpoint(checkpoint, config):
    """
    Convert the VAE part of an LDM state dict to the diffusers key layout.

    Args:
        checkpoint (`dict`): full original state dict; only keys starting with `"first_stage_model."` are used. The
            values are read with `.get`, so `checkpoint` itself is left unmodified.
        config: forwarded unchanged to `assign_to_checkpoint`.

    Returns:
        `dict`: new state dict keyed with the diffusers VAE names.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    # Top-level encoder/decoder convolutions and norms: same tensors, only `norm_out` is renamed `conv_norm_out`.
    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    # (the block count is the number of distinct three-segment prefixes, e.g. "encoder.down.0")
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    # Encoder down blocks: downsampler convs are moved over directly (and popped), resnets go through the renamer.
    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Encoder mid block: the source names its two resnets block_1 / block_2; the targets are resnets.0 / resnets.1.
    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    # Decoder up blocks: source block `block_id` is written to target index `i`, i.e. in reversed order.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        # Unlike the encoder downsamplers, these entries are copied without being popped from `vae_state_dict`.
        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    # Decoder mid block: same treatment as the encoder mid block above.
    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
    """
    Load the CLIP text-encoder weights found in an LDM / Stable Diffusion state dict into a text model.

    Args:
        checkpoint (`dict`): original state dict. Entries whose key starts with `"cond_stage_model.transformer"` or
            `"conditioner.embedders.0.transformer"` are used, with that prefix (and the following dot) stripped.
        local_files_only (`bool`): forwarded to `CLIPTextConfig.from_pretrained` when a fresh model has to be
            created, so the config can be resolved without network access.
        text_encoder (*optional*): an existing model to receive the weights. When `None`, an empty-weight
            `CLIPTextModel` configured from `"openai/clip-vit-large-patch14"` is created.

    Returns:
        The text model with the matching checkpoint tensors assigned on CPU.
    """
    if text_encoder is None:
        config_name = "openai/clip-vit-large-patch14"
        # Fixed: `local_files_only` used to be accepted but never forwarded.
        config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)

        with init_empty_weights():
            text_model = CLIPTextModel(config)
    else:
        # Fixed: a caller-supplied encoder was previously ignored, leaving `text_model` unbound.
        text_model = text_encoder

    remove_prefixes = ("cond_stage_model.transformer", "conditioner.embedders.0.transformer")

    text_model_dict = {}
    for key in list(checkpoint.keys()):
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

    for param_name, param in text_model_dict.items():
        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

    return text_model
def convert_paint_by_example_checkpoint(checkpoint):
    """
    Build a `PaintByExampleImageEncoder` and fill it from an original Paint-by-Example checkpoint.

    The following groups of `checkpoint` entries are consumed:
      * `cond_stage_model.transformer*` -> the CLIP vision model (`model.model`)
      * `cond_stage_model.mapper*`      -> the mapper blocks (`model.mapper`), with fused tensors split evenly
      * `cond_stage_model.final_ln.*`   -> `model.final_layer_norm`
      * `proj_out.*`                    -> `model.proj_out`
      * `learnable_vector`              -> `model.uncond_vector`

    Args:
        checkpoint (`dict`): flat state dict of the original checkpoint.

    Returns:
        The populated `PaintByExampleImageEncoder`.
    """
    vision_config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
    model = PaintByExampleImageEncoder(vision_config)

    # load clip vision
    transformer_strip = len("cond_stage_model.transformer.")
    vision_state = {
        ckpt_key[transformer_strip:]: checkpoint[ckpt_key]
        for ckpt_key in list(checkpoint.keys())
        if ckpt_key.startswith("cond_stage_model.transformer")
    }
    model.model.load_state_dict(vision_state)

    # load mapper
    mapper_strip = len("cond_stage_model.mapper.res")
    mapper_entries = {}
    for ckpt_key, tensor in checkpoint.items():
        if ckpt_key.startswith("cond_stage_model.mapper"):
            mapper_entries[ckpt_key[mapper_strip:]] = tensor

    # original layer name -> target parameter group(s); several targets mean the tensor is split
    # into that many equal chunks along its first dimension
    split_targets = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    block_prefix_len = len("blocks.i")
    mapper_state = {}
    for key, value in mapper_entries.items():
        # presumably keys look like "blocks.<i>.<layer name>.<weight|bias>" -- verify against a real checkpoint
        block_prefix = key[:block_prefix_len]
        remainder = key.split(block_prefix)[-1]
        param_kind = remainder.split(".")[-1]
        layer_name = remainder.split(param_kind)[0][1:-1]
        targets = split_targets[layer_name]

        chunk = value.shape[0] // len(targets)
        for index, target in enumerate(targets):
            target_key = ".".join([block_prefix, target, param_kind])
            mapper_state[target_key] = value[index * chunk : (index + 1) * chunk]

    model.mapper.load_state_dict(mapper_state)

    # load final layer norm
    final_ln_state = {
        "bias": checkpoint["cond_stage_model.final_ln.bias"],
        "weight": checkpoint["cond_stage_model.final_ln.weight"],
    }
    model.final_layer_norm.load_state_dict(final_ln_state)

    # load final proj
    proj_out_state = {
        "bias": checkpoint["proj_out.bias"],
        "weight": checkpoint["proj_out.weight"],
    }
    model.proj_out.load_state_dict(proj_out_state)

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model
def stable_unclip_image_encoder(original_config):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    The embedder class is read from `original_config.model.params.embedder_config.target` (last dotted
    component). Two classes are handled:

    * `ClipImageEmbedder` with `params.model == "ViT-L/14"` -> `openai/clip-vit-large-patch14`
    * `FrozenOpenCLIPImageEmbedder` -> `laion/CLIP-ViT-H-14-laion2B-s32B-b79K`

    Raises:
        NotImplementedError: for any other embedder class, or any other CLIP model name.
    """
    embedder_config = original_config.model.params.embedder_config
    embedder_class = embedder_config.target.split(".")[-1]

    if embedder_class == "ClipImageEmbedder":
        clip_model_name = embedder_config.params.model
        if clip_model_name != "ViT-L/14":
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
        encoder_repo = "openai/clip-vit-large-patch14"
    elif embedder_class == "FrozenOpenCLIPImageEmbedder":
        encoder_repo = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {embedder_class}"
        )

    # both supported embedders use a default image processor and differ only in the encoder weights
    feature_extractor = CLIPImageProcessor()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(encoder_repo)

    return feature_extractor, image_encoder
""" noise_aug_config = original_config.model.params.noise_aug_config noise_aug_class = noise_aug_config.target noise_aug_class = noise_aug_class.split(".")[-1] if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": noise_aug_config = noise_aug_config.params embedding_dim = noise_aug_config.timestep_dim max_noise_level = noise_aug_config.noise_schedule_config.timesteps beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) if "clip_stats_path" in noise_aug_config: if clip_stats_path is None: raise ValueError("This stable unclip config requires a `clip_stats_path`") clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) clip_mean = clip_mean[None, :] clip_std = clip_std[None, :] clip_stats_state_dict = { "mean": clip_mean, "std": clip_std, } image_normalizer.load_state_dict(clip_stats_state_dict) else: raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") return image_normalizer, image_noising_scheduler def convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema, use_linear_projection=None, cross_attention_dim=None, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") if use_linear_projection is not None: ctrlnet_config["use_linear_projection"] = use_linear_projection if cross_attention_dim is not None: ctrlnet_config["cross_attention_dim"] = cross_attention_dim controlnet_model = ControlNetModel(**ctrlnet_config) # Some controlnet ckpt files are distributed independently from the rest of the # model components i.e. 
def download_from_original_stable_diffusion_ckpt(
    checkpoint_path: str,
    original_config_file: str = None,
    image_size: Optional[int] = None,
    prediction_type: str = None,
    model_type: str = None,
    extract_ema: bool = False,
    scheduler_type: str = "pndm",
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: str = None,
    from_safetensors: bool = False,
    stable_unclip: Optional[str] = None,
    stable_unclip_prior: Optional[str] = None,
    clip_stats_path: Optional[str] = None,
    controlnet: Optional[bool] = None,
    load_safety_checker: bool = True,
    pipeline_class: DiffusionPipeline = None,
    local_files_only=False,
    vae_path=None,
    text_encoder=None,
    tokenizer=None,
) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
    config file.

    Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
    global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
    recommended that you override the default values and/or supply an `original_config_file` wherever possible.

    Args:
        checkpoint_path (`str`): Path to `.ckpt` file.
        original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, a config is
            downloaded; which one is decided by looking for keys that only exist in SD2.x / SDXL base / SDXL
            refiner checkpoints, falling back to the v1 inference config.
        image_size (`int`, *optional*, defaults to 512):
            The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
            Base. Use 768 for Stable Diffusion v2.
        prediction_type (`str`, *optional*):
            The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
            Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2. The spelling `'v-prediction'` is
            accepted and normalized.
        model_type (`str`, *optional*, defaults to `None`):
            The pipeline type. `None` to automatically infer from the config, or one of
            `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample", "SDXL", "SDXL-Refiner"]`. Any
            other value is converted as an LDM-BERT text-to-image model.
        extract_ema (`bool`, *optional*, defaults to `False`):
            Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
            weights or not. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality
            images for inference. Non-EMA weights are usually better to continue fine-tuning.
        scheduler_type (`str`, *optional*, defaults to 'pndm'):
            Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
            "ddim"]`. Forced to `"euler"` for SDXL model types.
        num_in_channels (`int`, *optional*, defaults to None):
            The number of input channels. If `None`, 9 is used when `pipeline_class` is
            `StableDiffusionInpaintPipeline` and 4 otherwise.
        upcast_attention (`bool`, *optional*, defaults to `None`):
            Whether the attention computation should always be upcasted. This is necessary when running stable
            diffusion 2.1.
        device (`str`, *optional*, defaults to `None`):
            The device to use. Pass `None` to determine automatically. Only used for non-safetensors checkpoints
            and for loading the unCLIP statistics.
        from_safetensors (`bool`, *optional*, defaults to `False`):
            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
        stable_unclip (`str`, *optional*, defaults to `None`):
            Only used when the model type is `FrozenOpenCLIPEmbedder`. `None` builds a regular pipeline,
            `"img2img"` a `StableUnCLIPImg2ImgPipeline` and `"txt2img"` a `StableUnCLIPPipeline`.
        stable_unclip_prior (`str`, *optional*, defaults to `None`):
            Prior to use when `stable_unclip == "txt2img"`. `None` and `"karlo"` both select the
            `kakaobrain/karlo-v1-alpha` prior.
        clip_stats_path (`str`, *optional*, defaults to `None`):
            Path to the CLIP statistics file, forwarded to `stable_unclip_image_noising_components`.
        controlnet (`bool`, *optional*, defaults to `None`):
            Whether to also convert a ControlNet. If `None`, it is enabled when the config contains
            `control_stage_config`.
        load_safety_checker (`bool`, *optional*, defaults to `True`):
            Whether to load the safety checker or not. Only honoured for the `FrozenCLIPEmbedder` model type.
        pipeline_class (`type`, *optional*, defaults to `None`):
            The pipeline class to use. `None` selects `StableDiffusionPipeline`.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether or not to only look at local files (i.e., do not try to download the model). In this function
            it is only forwarded to `convert_ldm_clip_checkpoint`.
        vae_path (`str`, *optional*, defaults to `None`):
            If given, the VAE is loaded with `AutoencoderKL.from_pretrained(vae_path)` instead of being converted
            from the checkpoint.
        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
            needed.
        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
            An instance of
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself,
            if needed.

    Returns:
        A pipeline object representing the passed-in `.ckpt`/`.safetensors` file.

    Raises:
        ValueError: if `omegaconf` (or `safetensors`, when requested) is not installed, or if `scheduler_type` is
            unknown.
        NotImplementedError: for an unknown `stable_unclip` or `stable_unclip_prior` value.
    """

    # import pipelines here to avoid circular import error when using from_single_file method
    from diffusers import (
        LDMTextToImagePipeline,
        PaintByExamplePipeline,
        StableDiffusionControlNetPipeline,
        StableDiffusionInpaintPipeline,
        StableDiffusionPipeline,
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLPipeline,
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
    )

    if pipeline_class is None:
        pipeline_class = StableDiffusionPipeline

    if prediction_type == "v-prediction":
        prediction_type = "v_prediction"

    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors.torch import load_file as safe_load

        checkpoint = safe_load(checkpoint_path, device="cpu")
    else:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(security): `torch.load` unpickles the file and can execute arbitrary code. Only load
        # `.ckpt` files from trusted sources; prefer `from_safetensors=True` otherwise.
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # Sometimes models don't have the global_step item
    if "global_step" in checkpoint:
        global_step = checkpoint["global_step"]
    else:
        logger.debug("global_step key not found in model")
        global_step = None

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    if original_config_file is None:
        key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
        key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"

        # model_type = "v1"
        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
            # model_type = "v2"
            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

            if global_step == 110000:
                # v2.1 needs to upcast attention
                upcast_attention = True
        elif key_name_sd_xl_base in checkpoint:
            # only base xl has two text embedders
            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
        elif key_name_sd_xl_refiner in checkpoint:
            # only refiner xl has embedder and one text embedders
            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

        # Bound the request and fail loudly on an HTTP error: without the status check an error
        # page would be handed to OmegaConf and surface later as a confusing config error.
        config_response = requests.get(config_url, timeout=60)
        config_response.raise_for_status()
        original_config_file = BytesIO(config_response.content)

    original_config = OmegaConf.load(original_config_file)

    # Convert the text model.
    if (
        model_type is None
        and "cond_stage_config" in original_config.model.params
        and original_config.model.params.cond_stage_config is not None
    ):
        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
        logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
    elif model_type is None and original_config.model.params.network_config is not None:
        if original_config.model.params.network_config.params.context_dim == 2048:
            model_type = "SDXL"
        else:
            model_type = "SDXL-Refiner"
        if image_size is None:
            image_size = 1024

    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
        num_in_channels = 9
    elif num_in_channels is None:
        num_in_channels = 4

    if "unet_config" in original_config.model.params:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
            # as it relies on a brittle global step parameter here
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
        if image_size is None:
            # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
            # as it relies on a brittle global step parameter here
            image_size = 512 if global_step == 875000 else 768
    else:
        if prediction_type is None:
            prediction_type = "epsilon"
        if image_size is None:
            image_size = 512

    if controlnet is None:
        controlnet = "control_stage_config" in original_config.model.params

    if controlnet:
        controlnet_model = convert_controlnet_checkpoint(
            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
        )

    num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

    if model_type in ["SDXL", "SDXL-Refiner"]:
        scheduler_dict = {
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "interpolation_type": "linear",
            "num_train_timesteps": num_train_timesteps,
            "prediction_type": "epsilon",
            "sample_max_value": 1.0,
            "set_alpha_to_one": False,
            "skip_prk_steps": True,
            "steps_offset": 1,
            "timestep_spacing": "leading",
        }
        scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
        scheduler_type = "euler"
    else:
        beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
        beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
        scheduler = DDIMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            steps_offset=1,
            clip_sample=False,
            set_alpha_to_one=False,
            prediction_type=prediction_type,
        )
    # make sure scheduler works correctly with DDIM
    scheduler.register_to_config(clip_sample=False)

    if scheduler_type == "pndm":
        config = dict(scheduler.config)
        config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(config)
    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
    elif scheduler_type == "ddim":
        pass  # the scheduler built above is used unchanged
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    unet_config["upcast_attention"] = upcast_attention
    with init_empty_weights():
        unet = UNet2DConditionModel(**unet_config)

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
    )

    for param_name, param in converted_unet_checkpoint.items():
        set_module_tensor_to_device(unet, param_name, "cpu", value=param)

    # Convert the VAE model.
    if vae_path is None:
        vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

        if (
            "model" in original_config
            and "params" in original_config.model
            and "scale_factor" in original_config.model.params
        ):
            vae_scaling_factor = original_config.model.params.scale_factor
        else:
            vae_scaling_factor = 0.18215  # default SD scaling factor

        vae_config["scaling_factor"] = vae_scaling_factor

        with init_empty_weights():
            vae = AutoencoderKL(**vae_config)

        for param_name, param in converted_vae_checkpoint.items():
            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
    else:
        vae = AutoencoderKL.from_pretrained(vae_path)

    if model_type == "FrozenOpenCLIPEmbedder":
        config_name = "stabilityai/stable-diffusion-2"
        config_kwargs = {"subfolder": "text_encoder"}

        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")

        if stable_unclip is None:
            if controlnet:
                pipe = StableDiffusionControlNetPipeline(
                    vae=vae,
                    text_encoder=text_model,
                    tokenizer=tokenizer,
                    unet=unet,
                    scheduler=scheduler,
                    controlnet=controlnet_model,
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
            else:
                pipe = pipeline_class(
                    vae=vae,
                    text_encoder=text_model,
                    tokenizer=tokenizer,
                    unet=unet,
                    scheduler=scheduler,
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                )
        else:
            image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
                original_config, clip_stats_path=clip_stats_path, device=device
            )

            if stable_unclip == "img2img":
                feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)

                pipe = StableUnCLIPImg2ImgPipeline(
                    # image encoding components
                    feature_extractor=feature_extractor,
                    image_encoder=image_encoder,
                    # image noising components
                    image_normalizer=image_normalizer,
                    image_noising_scheduler=image_noising_scheduler,
                    # regular denoising components
                    tokenizer=tokenizer,
                    text_encoder=text_model,
                    unet=unet,
                    scheduler=scheduler,
                    # vae
                    vae=vae,
                )
            elif stable_unclip == "txt2img":
                if stable_unclip_prior is None or stable_unclip_prior == "karlo":
                    karlo_model = "kakaobrain/karlo-v1-alpha"
                    prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")

                    prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
                    prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

                    prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
                    prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
                else:
                    raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")

                pipe = StableUnCLIPPipeline(
                    # prior components
                    prior_tokenizer=prior_tokenizer,
                    prior_text_encoder=prior_text_model,
                    prior=prior,
                    prior_scheduler=prior_scheduler,
                    # image noising components
                    image_normalizer=image_normalizer,
                    image_noising_scheduler=image_noising_scheduler,
                    # regular denoising components
                    tokenizer=tokenizer,
                    text_encoder=text_model,
                    unet=unet,
                    scheduler=scheduler,
                    # vae
                    vae=vae,
                )
            else:
                raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
    elif model_type == "PaintByExample":
        vision_model = convert_paint_by_example_checkpoint(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        pipe = PaintByExamplePipeline(
            vae=vae,
            image_encoder=vision_model,
            unet=unet,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=feature_extractor,
        )
    elif model_type == "FrozenCLIPEmbedder":
        text_model = convert_ldm_clip_checkpoint(
            checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
        )
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer

        if load_safety_checker:
            safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        else:
            safety_checker = None
            feature_extractor = None

        if controlnet:
            pipe = StableDiffusionControlNetPipeline(
                vae=vae,
                text_encoder=text_model,
                tokenizer=tokenizer,
                unet=unet,
                controlnet=controlnet_model,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
            )
        else:
            pipe = pipeline_class(
                vae=vae,
                text_encoder=text_model,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                safety_checker=safety_checker,
                feature_extractor=feature_extractor,
            )
    elif model_type in ["SDXL", "SDXL-Refiner"]:
        if model_type == "SDXL":
            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")

            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
            config_kwargs = {"projection_dim": 1280}
            text_encoder_2 = convert_open_clip_checkpoint(
                checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
            )

            pipe = StableDiffusionXLPipeline(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                text_encoder_2=text_encoder_2,
                tokenizer_2=tokenizer_2,
                unet=unet,
                scheduler=scheduler,
                force_zeros_for_empty_prompt=True,
            )
        else:
            # the refiner only carries the second (OpenCLIP) text encoder
            tokenizer = None
            text_encoder = None
            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")

            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
            config_kwargs = {"projection_dim": 1280}
            text_encoder_2 = convert_open_clip_checkpoint(
                checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
            )

            pipe = StableDiffusionXLImg2ImgPipeline(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                text_encoder_2=text_encoder_2,
                tokenizer_2=tokenizer_2,
                unet=unet,
                scheduler=scheduler,
                requires_aesthetics_score=True,
                force_zeros_for_empty_prompt=False,
            )
    else:
        text_config = create_ldm_bert_config(original_config)
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    return pipe
map_location=device) # NOTE: this while loop isn't great but this controlnet checkpoint has one additional # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] original_config = OmegaConf.load(original_config_file) if num_in_channels is not None: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels if "control_stage_config" not in original_config.model.params: raise ValueError("`control_stage_config` not present in original config") controlnet_model = convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema, use_linear_projection=use_linear_projection, cross_attention_dim=cross_attention_dim, ) return controlnet_model ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def preprocess(image):
    """
    Deprecated helper that turns an image input into a torch tensor batch.

    A `torch.Tensor` is returned untouched. A single PIL image is wrapped in a list. A list of PIL images is
    resized so width and height are multiples of 8 (sizes taken from the first image), stacked, scaled to
    `[-1, 1]` and returned as an NCHW float32 tensor. A list of tensors is concatenated along dim 0. Any other
    input is returned as received.

    Always emits a `FutureWarning` pointing at `VaeImageProcessor.preprocess`.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. "
        "Please use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        width, height = image[0].size
        # round both sides down to an integer multiple of 8
        width = width - width % 8
        height = height - height % 8

        frames = [
            np.array(pil_image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for pil_image in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(image[0], torch.Tensor):
        return torch.cat(image, dim=0)

    return image
def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
    """
    Solve the DDIM update (formula (12) of https://arxiv.org/pdf/2010.02502.pdf) for the noise term.

    Given the current sample `latents`, the already known previous sample `prev_latents` and the model's
    `noise_pred`, return the noise that makes a stochastic DDIM step with this `eta` land on `prev_latents`.

    Args:
        scheduler: object providing `config.num_train_timesteps`, `config.clip_sample`,
            `num_inference_steps`, `alphas_cumprod`, `final_alpha_cumprod` and `_get_variance`.
        prev_latents: sample at the previous timestep.
        latents: sample at `timestep`.
        timestep: current timestep, used to index `scheduler.alphas_cumprod`.
        noise_pred: noise predicted by the model for `latents`.
        eta: DDIM eta; the result is divided by `sqrt(variance) * eta`.

    Returns:
        The recovered noise.
    """
    # previous timestep on the inference schedule
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    prev_timestep = timestep - stride

    # cumulative alphas at the current and previous timesteps
    alpha_bar_t = scheduler.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_bar_prev = scheduler.alphas_cumprod[prev_timestep]
    else:
        alpha_bar_prev = scheduler.final_alpha_cumprod
    beta_bar_t = 1 - alpha_bar_t

    # "predicted x_0" of formula (12), optionally clipped to [-1, 1]
    x0_estimate = (latents - beta_bar_t ** (0.5) * noise_pred) / alpha_bar_t ** (0.5)
    if scheduler.config.clip_sample:
        x0_estimate = torch.clamp(x0_estimate, -1, 1)

    # sigma_t(eta), see formula (16)
    variance = scheduler._get_variance(timestep, prev_timestep)
    sigma_t = eta * variance ** (0.5)

    # "direction pointing to x_t" of formula (12)
    direction = (1 - alpha_bar_prev - sigma_t**2) ** (0.5) * noise_pred

    # deterministic part of the step; whatever is left over is the scaled noise
    deterministic_part = alpha_bar_prev ** (0.5) * x0_estimate + direction
    noise_scale = variance ** (0.5) * eta
    return (prev_latents - deterministic_part) / noise_scale
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: DDIMScheduler,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components after patching outdated scheduler / unet configs.

    - A scheduler config whose `steps_offset` is not 1 is deprecated and rewritten to `steps_offset=1`.
    - A unet config older than diffusers 0.9.0 with `sample_size < 64` is deprecated and rewritten to
      `sample_size=64`.
    - Passing `safety_checker=None` while `requires_safety_checker` is true only logs a warning.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # the config is frozen, so swap in a patched copy instead of mutating it
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so that `{self.__class__}` is actually interpolated into the message
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # each VAE block after the first halves the spatial resolution
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """
    Validate the call arguments and raise `ValueError` on the first violation found.

    Checks, in order: `strength` lies in `[0, 1]`; `callback_steps` is a positive `int`; exactly one of
    `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`; `negative_prompt` and
    `negative_prompt_embeds` are not both given; `prompt_embeds` and `negative_prompt_embeds` share a shape
    when both are given. Returns `None` when everything is valid.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # `None` is not an `int`, so this single test also rejects a missing value
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    prompt_given = prompt is not None
    embeds_given = prompt_embeds is not None

    if prompt_given and embeds_given:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not prompt_given and not embeds_given:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt_given and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    negative_embeds_given = negative_prompt_embeds is not None

    if negative_prompt is not None and negative_embeds_given:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if embeds_given and negative_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timesteps that corresponds to `strength`.

    `int(num_inference_steps * strength)` steps (capped at `num_inference_steps`) are kept; the earlier ones
    are skipped, with the skip offset multiplied by `self.scheduler.order`. `device` is accepted but not used
    here.

    Returns:
        A tuple `(timesteps, steps)`: the sliced timesteps and the number of inference steps left to run.
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    schedule = self.scheduler.timesteps
    remaining = schedule[first_step * self.scheduler.order :]

    return remaining, num_inference_steps - first_step
) if isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    source_prompt: Union[str, List[str]],
    image: Union[
        torch.FloatTensor,
        PIL.Image.Image,
        np.ndarray,
        List[torch.FloatTensor],
        List[PIL.Image.Image],
        List[np.ndarray],
    ] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    source_guidance_scale: Optional[float] = 1,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        source_prompt (`str` or `List[str]`):
            The prompt or prompts for the source branch. It is encoded separately from `prompt` and its
            embeddings are fed to the UNet together with the source latents.
        image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process. Can also accept image latents as `image`; if passing latents directly, it will not be
            encoded again.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter will be modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        source_guidance_scale (`float`, *optional*, defaults to 1):
            Guidance scale for the source prompt. This is useful to control the amount of influence the source
            prompt for encoding.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.1):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the latents are returned without VAE decoding or safety checking.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also used as the LoRA scale for encoding `prompt`.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """
    # 1. Check inputs
    # NOTE(review): `prompt_embeds` is not forwarded to `check_inputs`, so `prompt` must always be given.
    self.check_inputs(prompt, strength, callback_steps)

    # 2. Define call parameters
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )
    # the trailing `None` is the `negative_prompt` argument of `_encode_prompt`
    source_prompt_embeds = self._encode_prompt(
        source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
    )

    # 4. Preprocess image
    image = self.image_processor.preprocess(image)

    # 5. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 6. Prepare latent variables
    latents, clean_latents = self.prepare_latents(
        image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )
    # the source and target branches start from the same noised latents
    source_latents = latents

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    # pulled out so it can be passed explicitly to `posterior_sample` without appearing twice via
    # `**extra_step_kwargs`; `scheduler.step` below receives `variance_noise` instead of a generator
    generator = extra_step_kwargs.pop("generator", None)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            source_latent_model_input = torch.cat([source_latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)

            # predict the noise residual
            # NOTE(review): only entries [0] and [1] of each tensor are used, so this appears to assume a
            # single image with classifier-free guidance enabled (embeddings holding an unconditional and
            # a conditional entry) - confirm before using larger batches or `guidance_scale <= 1`.
            concat_latent_model_input = torch.stack(
                [
                    source_latent_model_input[0],
                    latent_model_input[0],
                    source_latent_model_input[1],
                    latent_model_input[1],
                ],
                dim=0,
            )
            concat_prompt_embeds = torch.stack(
                [
                    source_prompt_embeds[0],
                    prompt_embeds[0],
                    source_prompt_embeds[1],
                    prompt_embeds[1],
                ],
                dim=0,
            )
            concat_noise_pred = self.unet(
                concat_latent_model_input,
                t,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_hidden_states=concat_prompt_embeds,
            ).sample

            # perform guidance
            # the four chunks come back in the same order the inputs were stacked above
            (
                source_noise_pred_uncond,
                noise_pred_uncond,
                source_noise_pred_text,
                noise_pred_text,
            ) = concat_noise_pred.chunk(4, dim=0)

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
                source_noise_pred_text - source_noise_pred_uncond
            )

            # Sample source_latents from the posterior distribution.
            prev_source_latents = posterior_sample(
                self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
            )
            # Compute noise.
            noise = compute_noise(
                self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
            )
            source_latents = prev_source_latents

            # compute the previous noisy sample x_t -> x_t-1
            # the noise recovered from the source branch is reused as the variance noise of the target step
            latents = self.scheduler.step(
                noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
            ).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 9. Post-processing
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # flagged images are excluded from denormalization
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from packaging import version from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import deprecate, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> from diffusers import FlaxStableDiffusionPipeline >>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", revision="bf16", dtype=jax.numpy.bfloat16 ... 
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            Stored on the pipeline as `self.dtype`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Old checkpoints (config written by diffusers < 0.9.0) may carry a too-small `sample_size`;
        # warn about it and patch the in-memory config to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Each VAE block after the first halves the spatial resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        """
        Tokenize `prompt` and return the token ids as a NumPy array.

        Prompts are padded / truncated to `self.tokenizer.model_max_length`.

        Raises:
            ValueError: if `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on pre-extracted image `features` and return its per-image flags."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Flag NSFW images and replace every flagged image with a black one.

        Args:
            images: array of `uint8` images indexed along the first axis. It is never modified in place; a copy is
                made before the first replacement.
            safety_model_params: parameters of the safety checker. They should already be replicated when `jit` is
                `True`.
            jit: whether to run the `pmap` version of the safety checker.

        Returns:
            `(images, has_nsfw_concepts)`: the (possibly copied and blacked-out) images and the per-image flags.
        """
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        flagged = [idx for idx, has_nsfw_concept in enumerate(has_nsfw_concepts) if has_nsfw_concept]
        if flagged:
            # Copy once so the caller's array is left untouched.
            images = images.copy()
            for idx in flagged:
                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jnp.ndarray,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: Union[float, jnp.ndarray],
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
    ):
        """
        Run the full denoising loop and decode the result with the VAE.

        Classifier-free guidance is always applied (see the TODO below).

        Args:
            prompt_ids: token ids of the prompts; the first axis is the batch.
            params: parameter dictionary with the keys `"text_encoder"`, `"unet"`, `"vae"` and `"scheduler"`.
            prng_seed: PRNG key used to sample the initial latents when `latents` is `None`.
            num_inference_steps: number of denoising steps.
            height, width: output size in pixels; both must be divisible by 8.
            guidance_scale: a scalar, or an array with one value per sample (or a single value).
            latents: optional initial latents of shape
                `(batch, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor)`.
            neg_prompt_ids: optional token ids for the unconditional branch; empty prompts are used when `None`.

        Returns:
            Decoded images, clipped to `[0, 1]`, with the channel axis last.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8 or `latents` has an unexpected shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # Ensure model output will be `float32` before going into the scheduler.
        # The scale is reshaped to (n, 1, 1, 1) so that one value per sample (or a single value)
        # broadcasts over the (batch, channels, height, width) noise prediction.
        guidance_scale = jnp.asarray(guidance_scale, dtype=jnp.float32).reshape(-1, 1, 1, 1)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jnp.ndarray,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.ndarray`):
                Token ids of the prompt or prompts to guide the image generation, as returned by `prepare_inputs`
                (and sharded when `jit` is `True`).
            params (`Dict` or `FrozenDict`):
                Dictionary containing the model parameters/weights.
            prng_seed (`jnp.ndarray`):
                Array containing the random number generator key.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            guidance_scale (`float` or `jnp.ndarray`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. A Python number is broadcast to one value per entry
                along the first axis of `prompt_ids`.
            latents (`jnp.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, latents
                are sampled using `prng_seed`.
            neg_prompt_ids (`jnp.ndarray`, *optional*):
                Token ids used for the unconditional branch of classifier-free guidance. Empty prompts are used when
                not provided.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is an array with the generated images, and the second
            element holds the per-image flags denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker` (`False` when the safety checker is
            disabled).
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(guidance_scale, (int, float)):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([float(guidance_scale)] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            # Works for both layouts: (devices, batch, H, W, 3) when sharded and (batch, H, W, 3) otherwise.
            output_shape = images.shape
            image_shape = output_shape[-3:]

            images_uint8_casted = (images * 255).round().astype("uint8")
            images_uint8_casted = np.asarray(images_uint8_casted).reshape(-1, *image_shape)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # `np.array` (unlike `np.asarray`) always gives a writable host copy; the NumPy view of a
            # device array may be read-only. Flatten so that index `i` addresses the i-th image.
            images = np.array(images).reshape(-1, *image_shape)

            # block images
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = images_uint8_casted[i]

            images = images.reshape(output_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
    jax.pmap,
    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
    static_broadcasted_argnums=(0, 4, 5, 6),
)
def _p_generate(
    pipe,
    prompt_ids,
    params,
    prng_seed,
    num_inference_steps,
    height,
    width,
    guidance_scale,
    latents,
    neg_prompt_ids,
):
    """`pmap`-ed wrapper around `pipe._generate`."""
    return pipe._generate(
        prompt_ids,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        neg_prompt_ids,
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_has_nsfw_concepts(pipe, features, params):
    """`pmap`-ed wrapper around `pipe._get_has_nsfw_concepts`."""
    return pipe._get_has_nsfw_concepts(features, params)


def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x`: `(d, b, ...) -> (d * b, ...)`."""
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    num_devices, batch_size = x.shape[:2]
    rest = x.shape[2:]
    return x.reshape(num_devices * batch_size, *rest)
# This shim only exists so that importing `FlaxStableDiffusionControlNetPipeline` from its old location
# temporarily keeps working: the import below re-exports the class and `deprecate` announces the move.
from ...utils import deprecate
from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline  # noqa: F401


# The message names this module's actual import path (`pipeline_flax_...`, matching the file name).
deprecate(
    "stable diffusion controlnet",
    "0.22.0",
    "Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.",
    standard_warn=False,
    stacklevel=3,
)
import warnings from functools import partial from typing import Dict, List, Optional, Union import jax import jax.numpy as jnp import numpy as np from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard from PIL import Image from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...schedulers import ( FlaxDDIMScheduler, FlaxDPMSolverMultistepScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring from ..pipeline_flax_utils import FlaxDiffusionPipeline from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False EXAMPLE_DOC_STRING = """ Examples: ```py >>> import jax >>> import numpy as np >>> import jax.numpy as jnp >>> from flax.jax_utils import replicate >>> from flax.training.common_utils import shard >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> from diffusers import FlaxStableDiffusionImg2ImgPipeline >>> def create_key(seed=0): ... return jax.random.PRNGKey(seed) >>> rng = create_key(0) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_img = Image.open(BytesIO(response.content)).convert("RGB") >>> init_img = init_img.resize((768, 512)) >>> prompts = "A fantasy landscape, trending on artstation" >>> pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", ... revision="flax", ... dtype=jnp.bfloat16, ... 
class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for image-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            Stored on the pipeline as `self.dtype`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            # `logger.warning` rather than the deprecated `warn` alias.
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Each VAE block after the first halves the spatial resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
        """
        Tokenize `prompt` and convert `image` into a model-ready array.

        Returns:
            `(prompt_ids, processed_images)`: the token ids (padded / truncated to
            `self.tokenizer.model_max_length`) and the images run through `preprocess` and concatenated along the
            first axis.

        Raises:
            ValueError: if `prompt` is neither `str` nor `list`, or `image` is neither a PIL image nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids, processed_images

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on pre-extracted image `features` and return its per-image flags."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Flag NSFW images and replace every flagged image with a black one.

        Args:
            images: array of `uint8` images indexed along the first axis. It is never modified in place; a copy is
                made before the first replacement.
            safety_model_params: parameters of the safety checker. They should already be replicated when `jit` is
                `True`.
            jit: whether to run the `pmap` version of the safety checker.

        Returns:
            `(images, has_nsfw_concepts)`: the (possibly copied and blacked-out) images and the per-image flags.
        """
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        flagged = [idx for idx, has_nsfw_concept in enumerate(has_nsfw_concepts) if has_nsfw_concept]
        if flagged:
            # Copy once so the caller's array is left untouched.
            images = images.copy()
            for idx in flagged:
                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def get_timestep_start(self, num_inference_steps, strength):
        """
        Return the index of the first denoising step to run for the given `strength`.

        `int(num_inference_steps * strength)` steps (capped at `num_inference_steps`) are executed, so the returned
        start index is `num_inference_steps` minus that amount, never below 0.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)

        return t_start

    def _generate(
        self,
        prompt_ids: jnp.ndarray,
        image: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jnp.ndarray,
        start_timestep: int,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: Union[float, jnp.ndarray],
        noise: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
    ):
        """
        Encode `image`, noise it to `start_timestep`, denoise it and decode the result with the VAE.

        Classifier-free guidance is always applied (see the TODO below).

        Args:
            prompt_ids: token ids of the prompts; the first axis is the batch.
            image: preprocessed image batch that is encoded by the VAE into the initial latents.
            params: parameter dictionary with the keys `"text_encoder"`, `"unet"`, `"vae"` and `"scheduler"`.
            prng_seed: PRNG key used to sample the VAE latent distribution and, when `noise` is `None`, the noise.
            start_timestep: index into the scheduler timesteps at which denoising starts.
            num_inference_steps: total number of scheduler steps; the loop runs from `start_timestep` up to it.
            height, width: output size in pixels; both must be divisible by 8.
            guidance_scale: a scalar, or an array with one value per sample (or a single value).
            noise: optional noise of shape
                `(batch, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor)`.
            neg_prompt_ids: optional token ids for the unconditional branch; empty prompts are used when `None`.

        Returns:
            Decoded images, clipped to `[0, 1]`, with the channel axis last.

        Raises:
            ValueError: if `height`/`width` are not divisible by 8 or `noise` has an unexpected shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # Reshape to (n, 1, 1, 1) so that one value per sample (or a single value) broadcasts over
        # the (batch, channels, height, width) noise prediction.
        guidance_scale = jnp.asarray(guidance_scale).reshape(-1, 1, 1, 1)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if noise is None:
            noise = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if noise.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {noise.shape}, expected {latents_shape}")

        # Create init_latents
        init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
        init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
        init_latents = self.vae.config.scaling_factor * init_latents

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape
        )

        # Noise the encoded image up to the timestep at which denoising starts.
        latent_timestep = scheduler_state.timesteps[start_timestep : start_timestep + 1].repeat(batch_size)

        latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(start_timestep, num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.ndarray,
        image: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jnp.ndarray,
        strength: float = 0.8,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        noise: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.ndarray`):
                Token ids of the prompt or prompts to guide the image generation, as returned by `prepare_inputs`
                (and sharded when `jit` is `True`).
            image (`jnp.ndarray`):
                Array representing an image batch, that will be used as the starting point for the process.
            params (`Dict` or `FrozenDict`):
                Dictionary containing the model parameters/weights.
            prng_seed (`jnp.ndarray`):
                Array containing the random number generator key.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            guidance_scale (`float` or `jnp.ndarray`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality. A Python number is broadcast to one value per entry
                along the first axis of `prompt_ids`.
            noise (`jnp.ndarray`, *optional*):
                Pre-generated noise, sampled from a Gaussian distribution, that is added to the encoded `image`. Can
                be used to tweak the same generation with different prompts. If not provided, noise is sampled using
                `prng_seed`.
            neg_prompt_ids (`jnp.ndarray`, *optional*):
                Token ids used for the unconditional branch of classifier-free guidance. Empty prompts are used when
                not provided.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is an array with the generated images, and the second
            element holds the per-image flags denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker` (`False` when the safety checker is
            disabled).
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(guidance_scale, (int, float)):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([float(guidance_scale)] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        start_timestep = self.get_timestep_start(num_inference_steps, strength)

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                image,
                params,
                prng_seed,
                start_timestep,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                noise,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                image,
                params,
                prng_seed,
                start_timestep,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                noise,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            # Works for both layouts: (devices, batch, H, W, 3) when sharded and (batch, H, W, 3) otherwise.
            output_shape = images.shape
            image_shape = output_shape[-3:]

            images_uint8_casted = (images * 255).round().astype("uint8")
            images_uint8_casted = np.asarray(images_uint8_casted).reshape(-1, *image_shape)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # `np.array` (unlike `np.asarray`) always gives a writable host copy; the NumPy view of a
            # device array may be read-only. Flatten so that index `i` addresses the i-th image.
            images = np.array(images).reshape(-1, *image_shape)

            # block images
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = images_uint8_casted[i]

            images = images.reshape(output_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, start_timestep, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial( jax.pmap, in_axes=(None, 0, 0, 0, 0, None, None, None, None, 0, 0, 0), static_broadcasted_argnums=(0, 5, 6, 7, 8), ) def _p_generate( pipe, prompt_ids, image, params, prng_seed, start_timestep, num_inference_steps, height, width, guidance_scale, noise, neg_prompt_ids, ): return pipe._generate( prompt_ids, image, params, prng_seed, start_timestep, num_inference_steps, height, width, guidance_scale, noise, neg_prompt_ids, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) def _p_get_has_nsfw_concepts(pipe, features, params): return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): # einops.rearrange(x, 'd b ... -> (d b) ...') num_devices, batch_size = x.shape[:2] rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) def preprocess(image, dtype): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import warnings
from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import jax
        >>> import numpy as np
        >>> from flax.jax_utils import replicate
        >>> from flax.training.common_utils import shard
        >>> import PIL
        >>> import requests
        >>> from io import BytesIO
        >>> from diffusers import FlaxStableDiffusionInpaintPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

        >>> init_image = download_image(img_url).resize((512, 512))
        >>> mask_image = download_image(mask_url).resize((512, 512))

        >>> pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(
        ...     "xvjiarui/stable-diffusion-2-inpainting"
        ... )

        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
        >>> prng_seed = jax.random.PRNGKey(0)
        >>> num_inference_steps = 50

        >>> num_samples = jax.device_count()
        >>> prompt = num_samples * [prompt]
        >>> init_image = num_samples * [init_image]
        >>> mask_image = num_samples * [mask_image]
        >>> prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(
        ...     prompt, init_image, mask_image
        ... )
        # shard inputs and rng

        >>> params = replicate(params)
        >>> prng_seed = jax.random.split(prng_seed, jax.device_count())
        >>> prompt_ids = shard(prompt_ids)
        >>> processed_masked_images = shard(processed_masked_images)
        >>> processed_masks = shard(processed_masks)

        >>> images = pipeline(
        ...     prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
        ... ).images
        >>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
        ```
"""


class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        # dtype used when sampling the initial latents in `_generate`.
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Configs written by diffusers < 0.9.0 with `sample_size` < 64 are patched to 64 (with a deprecation notice).
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Ratio between pixel resolution and latent resolution; used to size the latents in `_generate`.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_inputs(
        self,
        prompt: Union[str, List[str]],
        image: Union[Image.Image, List[Image.Image]],
        mask: Union[Image.Image, List[Image.Image]],
    ):
        """Tokenize the prompt(s) and turn image(s)/mask(s) into arrays for the pipeline.

        Args:
            prompt (`str` or `List[str]`): The prompt or prompts to tokenize.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`): The image(s) to inpaint.
            mask (`PIL.Image.Image` or `List[PIL.Image.Image]`):
                The mask(s). After preprocessing they are binarized at 0.5; image pixels whose mask value is
                `>= 0.5` are zeroed out in the returned masked images.

        Returns:
            `tuple`: `(prompt_ids, processed_masked_images, processed_masks)`.

        Raises:
            ValueError: If `prompt`, `image` or `mask` has an unsupported type.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(image, (Image.Image, list)):
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        if isinstance(image, Image.Image):
            image = [image]

        if not isinstance(mask, (Image.Image, list)):
            # Report the offending argument (`mask`), not `image`.
            raise ValueError(f"mask has to be of type `PIL.Image.Image` or list but is {type(mask)}")

        if isinstance(mask, Image.Image):
            mask = [mask]

        processed_images = jnp.concatenate([preprocess_image(img, jnp.float32) for img in image])
        processed_masks = jnp.concatenate([preprocess_mask(m, jnp.float32) for m in mask])
        # Binarize the masks: values < 0.5 -> 0, values >= 0.5 -> 1.
        processed_masks = processed_masks.at[processed_masks < 0.5].set(0)
        processed_masks = processed_masks.at[processed_masks >= 0.5].set(1)

        # Keep only the pixels outside the masked region.
        processed_masked_images = processed_images * (processed_masks < 0.5)

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids, processed_masked_images, processed_masks

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on `features` with `params` and return its result."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """Flag NSFW images and replace each flagged image with a black one.

        Args:
            images: Array of `uint8` images; each entry is converted with `Image.fromarray`.
            safety_model_params: Safety checker parameters (already replicated when `jit` is `True`).
            jit (`bool`): Whether to use the `pmap` version of the safety scoring function.

        Returns:
            `tuple`: `(images, has_nsfw_concepts)`. `images` is a copy of the input only if something was flagged.
        """
        # safety_model_params should already be replicated when jit is True
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            # NOTE(review): the un-replicated params are not used afterwards.
            safety_model_params = unreplicate(safety_model_params)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                # Copy lazily so the caller's array is never modified in place.
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            if any(has_nsfw_concepts):
                warnings.warn(
                    "Potential NSFW content was detected in one or more images. A black image will be returned"
                    " instead. Try again with a different prompt and/or seed."
                )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.array,
        mask: jnp.array,
        masked_image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
    ):
        """Run the denoising loop with classifier-free guidance and decode the result.

        Reads `params["text_encoder"]`, `params["vae"]`, `params["unet"]` and `params["scheduler"]`.

        Returns:
            The decoded images, rescaled to `[0, 1]` and transposed with `(0, 2, 3, 1)`.

        Raises:
            ValueError: If `height` or `width` is not divisible by 8, if `latents` has an unexpected shape, or if
                the latent, mask and masked-image channel counts do not sum to `unet.config.in_channels`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            # No negative prompt given: use the empty prompt as the unconditional input.
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        latents_shape = (
            batch_size,
            self.vae.config.latent_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        prng_seed, mask_prng_seed = jax.random.split(prng_seed)

        # Encode the masked image and sample its latents.
        masked_image_latent_dist = self.vae.apply(
            {"params": params["vae"]}, masked_image, method=self.vae.encode
        ).latent_dist
        masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
        del mask_prng_seed

        # Bring the mask to the spatial size of the masked-image latents.
        mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")

        # 8. Check that sizes of mask, masked image and latents match
        num_channels_latents = self.vae.config.latent_channels
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )

        def loop_body(step, args):
            latents, mask, masked_image_latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)
            mask_input = jnp.concatenate([mask] * 2)
            masked_image_latents_input = jnp.concatenate([masked_image_latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
            # concat latents, mask, masked_image_latents in the channel dimension
            latents_input = jnp.concatenate([latents_input, mask_input, masked_image_latents_input], axis=1)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, mask, masked_image_latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, mask, masked_image_latents, scheduler_state = loop_body(
                    i, (latents, mask, masked_image_latents, scheduler_state)
                )
        else:
            latents, _, _, _ = jax.lax.fori_loop(
                0, num_inference_steps, loop_body, (latents, mask, masked_image_latents, scheduler_state)
            )

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.array,
        mask: jnp.array,
        masked_image: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.KeyArray,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.array] = 7.5,
        latents: Optional[jnp.array] = None,
        neg_prompt_ids: Optional[jnp.array] = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt_ids (`jnp.array`):
                The tokenized prompt ids that guide the image generation, e.g. the first value returned by
                `prepare_inputs`.
            mask (`jnp.array`):
                The mask array, e.g. the third value returned by `prepare_inputs`. It is resized to
                (`height`, `width`) with nearest-neighbour interpolation.
            masked_image (`jnp.array`):
                The masked image array, e.g. the second value returned by `prepare_inputs`. It is resized to
                (`height`, `width`) with bicubic interpolation.
            params (`Dict` or `FrozenDict`):
                The model parameters. The entries `"text_encoder"`, `"vae"`, `"unet"` and `"scheduler"` are read
                during generation, and `"safety_checker"` is read when a safety checker is set.
            prng_seed (`jax.random.KeyArray`):
                Random key used to sample the masked image latents and, if `latents` is not given, the initial
                latents.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied `prng_seed`.
            neg_prompt_ids (`jnp.array`, *optional*):
                Tokenized negative prompt ids. If not provided, the empty prompt is used as the unconditional input.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
                exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        masked_image = jax.image.resize(masked_image, (*masked_image.shape[:-2], height, width), method="bicubic")
        mask = jax.image.resize(mask, (*mask.shape[:-2], height, width), method="nearest")

        if isinstance(guidance_scale, float):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                mask,
                masked_image,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                mask,
                masked_image,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            # Remember the leading dims ((num_devices, batch_size) under `jit`) so they can be restored below.
            leading_shape = images.shape[:-3]

            images_uint8_casted = (images * 255).round().astype("uint8")
            images_uint8_casted = np.asarray(images_uint8_casted).reshape(-1, height, width, 3)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # `np.array` makes a writable copy (a NumPy view of a JAX array can be read-only), and flattening
            # the leading dims makes index `i` line up with `has_nsfw_concept[i]`.
            images = np.array(images).reshape(-1, height, width, 3)

            # block images
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    images[i] = np.asarray(images_uint8_casted[i])

            images = images.reshape(*leading_shape, height, width, 3)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)


# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial( jax.pmap, in_axes=(None, 0, 0, 0, 0, 0, None, None, None, 0, 0, 0), static_broadcasted_argnums=(0, 6, 7, 8), ) def _p_generate( pipe, prompt_ids, mask, masked_image, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids, ): return pipe._generate( prompt_ids, mask, masked_image, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids, ) @partial(jax.pmap, static_broadcasted_argnums=(0,)) def _p_get_has_nsfw_concepts(pipe, features, params): return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): # einops.rearrange(x, 'd b ... -> (d b) ...') num_devices, batch_size = x.shape[:2] rest = x.shape[2:] return x.reshape(num_devices * batch_size, *rest) def preprocess_image(image, dtype): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = jnp.array(image).astype(dtype) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 def preprocess_mask(mask, dtype): w, h = mask.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 mask = mask.resize((w, h)) mask = jnp.array(mask.convert("L")).astype(dtype) / 255.0 mask = jnp.expand_dims(mask, axis=(0, 1)) return mask ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Callable, List, Optional, Union import numpy as np import torch from transformers import CLIPImageProcessor, CLIPTokenizer from ...configuration_utils import FrozenDict from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import deprecate, logging from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput logger = logging.get_logger(__name__) class OnnxStableDiffusionPipeline(DiffusionPipeline): vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer unet: OnnxRuntimeModel scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] safety_checker: OnnxRuntimeModel feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int],
    width: Optional[int],
    callback_steps: int,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validate the arguments of a generation call and raise on the first inconsistency found.

    Args:
        prompt (`str` or `List[str]`):
            Text prompt(s). Exactly one of `prompt` and `prompt_embeds` must be given.
        height (`int`):
            Requested image height in pixels; must be a multiple of 8.
        width (`int`):
            Requested image width in pixels; must be a multiple of 8.
        callback_steps (`int`):
            Callback frequency; must be a strictly positive `int`.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s). Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-computed prompt embeddings.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-computed negative prompt embeddings. When passed together with `prompt_embeds`, both arrays
            must have the same shape.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # `__call__` allocates latents of size `height // 8` x `width // 8`, hence the divisibility rule.
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None`, non-integers and non-positive integers are all rejected with the same message.
    steps_are_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    prompt_given = prompt is not None
    embeds_given = prompt_embeds is not None

    if prompt_given and embeds_given:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not prompt_given and not embeds_given:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt_given and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    negative_embeds_given = negative_prompt_embeds is not None

    if negative_prompt is not None and negative_embeds_given:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # The two embedding arrays are later concatenated along the batch axis, so their shapes must agree.
    if embeds_given and negative_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): One or a list of [numpy generator(s)](TODO) to make generation deterministic. latents (`np.ndarray`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if generator is None: generator = np.random # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # get the initial random noise unless the user supplied it latents_dtype = prompt_embeds.dtype latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) if latents is None: latents = generator.randn(*latents_shape).astype(latents_dtype) elif latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # set timesteps self.scheduler.set_timesteps(num_inference_steps) latents = latents * np.float64(self.scheduler.init_noise_sigma) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) latent_model_input = latent_model_input.cpu().numpy() # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) noise_pred = noise_pred[0] # perform guidance if 
do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs ) latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 image = np.concatenate( [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] ) image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) if self.safety_checker is not None: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] ) images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) else: has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) class StableDiffusionOnnxPipeline(OnnxStableDiffusionPipeline): def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPImageProcessor, ): deprecation_message = "Please use `OnnxStableDiffusionPipeline` 
def preprocess(image):
    """
    Convert an input image (or batch of images) into a `torch.Tensor` (deprecated helper).

    A `FutureWarning` is always emitted first. Afterwards:

    * a `torch.Tensor` is returned unchanged;
    * a single `PIL.Image.Image` is treated as a one-element list;
    * a list of PIL images is resized to the size of the first image rounded down to a multiple of 64,
      stacked, scaled from `[0, 255]` to `[-1, 1]` and returned in channels-first layout;
    * a list of tensors is concatenated along dimension 0;
    * anything else is returned as received.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    if isinstance(image, torch.Tensor):
        return image

    batch = [image] if isinstance(image, PIL.Image.Image) else image
    first = batch[0]

    if isinstance(first, PIL.Image.Image):
        # Every image is resized to the first image's size, rounded down to an integer multiple of 64.
        width, height = first.size
        width -= width % 64
        height -= height % 64

        resized = [
            np.array(member.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for member in batch
        ]
        pixels = np.concatenate(resized, axis=0)
        pixels = np.array(pixels).astype(np.float32) / 255.0
        pixels = pixels.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        return torch.from_numpy(2.0 * pixels - 1.0)

    if isinstance(first, torch.Tensor):
        return torch.cat(batch, dim=0)

    return batch
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    r"""
    Register the pipeline components, patching outdated scheduler configurations on the way.

    Args:
        vae_encoder, vae_decoder, text_encoder, unet (`OnnxRuntimeModel`):
            ONNX sub-models registered on the pipeline.
        tokenizer (`CLIPTokenizer`):
            Tokenizer registered on the pipeline.
        scheduler:
            Scheduler registered on the pipeline. If its config has `steps_offset != 1` or
            `clip_sample is True`, a deprecation notice is issued and the config is replaced in place by a
            corrected copy (`steps_offset=1`, `clip_sample=False`).
        safety_checker (`OnnxRuntimeModel`):
            Optional safety checker; may be `None`.
        feature_extractor (`CLIPImageProcessor`):
            Required whenever `safety_checker` is not `None`.
        requires_safety_checker (`bool`, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # The config is frozen, so it is rebuilt from a mutable copy rather than edited in place.
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first literal must be an f-string, otherwise `{self.__class__}` is printed verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    callback_steps: int,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validate the arguments of an image-to-image call and raise on the first inconsistency found.

    Args:
        prompt (`str` or `List[str]`):
            Text prompt(s). Exactly one of `prompt` and `prompt_embeds` must be given.
        callback_steps (`int`):
            Callback frequency; must be a strictly positive `int`.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s). Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-computed prompt embeddings.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-computed negative prompt embeddings. When passed together with `prompt_embeds`, both arrays
            must have the same shape.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # `None`, non-integers and non-positive integers are all rejected with the same message.
    steps_are_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    prompt_given = prompt is not None
    embeds_given = prompt_embeds is not None

    if prompt_given and embeds_given:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not prompt_given and not embeds_given:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt_given and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    negative_embeds_given = negative_prompt_embeds is not None

    if negative_prompt is not None and negative_embeds_given:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # The two embedding arrays are later concatenated along the batch axis, so their shapes must agree.
    if embeds_given and negative_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # check inputs. Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if generator is None: generator = np.random # set timesteps self.scheduler.set_timesteps(num_inference_steps) image = preprocess(image).cpu().numpy() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) # encode the init image into latents and scale the latents init_latents = self.vae_encoder(sample=image)[0] init_latents = 0.18215 * init_latents if isinstance(prompt, str): prompt = [prompt] if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = len(prompt) // init_latents.shape[0] init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." 
) else: init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps.numpy()[-init_timestep] timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) # add noise to latents using the timesteps noise = generator.randn(*init_latents.shape).astype(latents_dtype) init_latents = self.scheduler.add_noise( torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) ) init_latents = init_latents.numpy() # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:].numpy() timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) latent_model_input = latent_model_input.cpu().numpy() # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, 
encoder_hidden_states=prompt_embeds)[ 0 ] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs ) latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 image = np.concatenate( [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] ) image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) if self.safety_checker is not None: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) # safety_checker does not support batched inputs yet images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] ) images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) else: has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
def prepare_mask_and_masked_image(image, mask, latents_shape):
    """Build the latent-resolution mask and the pixel-space masked image.

    Args:
        image: PIL image to inpaint; converted to RGB and resized to 8x the latent size.
        mask: PIL mask image; converted to single-channel luminance ("L").
        latents_shape: `(height, width)` of the latent grid.

    Returns:
        A `(mask, masked_image)` tuple. `mask` is a float32 array of shape `(1, 1, height, width)` that
        contains only 0.0 and 1.0. `masked_image` is the `(1, 3, 8 * height, 8 * width)` float32 image
        scaled to `[-1, 1]` and multiplied by zero wherever the resized mask luminance is 127.5 or above.
    """
    latent_h, latent_w = latents_shape[0], latents_shape[1]
    pixel_size = (latent_w * 8, latent_h * 8)  # PIL sizes are (width, height)

    # HWC uint8 -> NCHW float32 in [-1, 1]
    rgb = np.array(image.convert("RGB").resize(pixel_size))
    pixels = rgb[None].transpose(0, 3, 1, 2)
    pixels = pixels.astype(np.float32) / 127.5 - 1.0

    # pixels under the dark half of the mask are kept, the rest are multiplied by zero
    keep_region = np.array(mask.convert("L").resize(pixel_size)) < 127.5
    masked_image = pixels * keep_region

    # binarise the mask at latent resolution
    small_mask = mask.resize((latent_w, latent_h), PIL_INTERPOLATION["nearest"])
    mask_values = np.array(small_mask.convert("L")).astype(np.float32) / 255.0
    binary_mask = (mask_values[None, None] >= 0.5).astype(np.float32)

    return binary_mask, masked_image
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and patch outdated scheduler configs.

    Args:
        vae_encoder: ONNX model used to encode the masked image into latents.
        vae_decoder: ONNX model used to decode latents back into images.
        text_encoder: ONNX text encoder producing the prompt embeddings.
        tokenizer: Tokenizer paired with `text_encoder`.
        unet: ONNX denoising network.
        scheduler: Scheduler driving the denoising loop. If its config has `steps_offset != 1` or
            `clip_sample is True`, a deprecation is emitted and the config is rewritten in place.
        safety_checker: Optional ONNX safety checker; `None` disables it.
        feature_extractor: Pre-processor for the safety checker inputs; required when
            `safety_checker` is given.
        requires_safety_checker: When `True`, a warning is logged if `safety_checker` is `None`.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()
    logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # the scheduler config is frozen, so replace it wholesale with a patched copy
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # NOTE: this literal needs the `f` prefix, otherwise `{self.__class__}` is printed verbatim
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: Optional[int],
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]],
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`):
            prompt to be encoded. May be `None` when `prompt_embeds` is supplied.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`):
            The prompt or prompts not to guide the image generation. Only used when
            `do_classifier_free_guidance` is true and `negative_prompt_embeds` is not given.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.

    Returns:
        `np.ndarray`: the prompt embeddings repeated `num_images_per_prompt` times along axis 0. With
        classifier free guidance the (repeated) negative embeddings are concatenated in front of them.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given but have different types.
        ValueError: if a list `negative_prompt` does not match the batch size.
    """
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

        if not np.array_equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            # only comparable when a text prompt was given; with precomputed `prompt_embeds`
            # (`prompt is None`) a textual negative prompt is still valid
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # pad the unconditional tokens to the sequence length of the prompt embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]

    if do_classifier_free_guidance:
        negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt: Union[str, List[str]],
    height: Optional[int],
    width: Optional[int],
    callback_steps: int,
    negative_prompt: Optional[str] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    """Validate the user-facing arguments of `__call__`.

    Checks, in order: `height`/`width` are multiples of 8; `callback_steps` is a positive `int`;
    exactly one of `prompt` / `prompt_embeds` is given and `prompt` is a `str` or `list`;
    `negative_prompt` and `negative_prompt_embeds` are not both given; and directly passed
    `prompt_embeds` / `negative_prompt_embeds` share the same shape.

    Raises:
        ValueError: on the first check that fails.
    """
    size_ok = height % 8 == 0 and width % 8 == 0
    if not size_ok:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]],
    image: PIL.Image.Image,
    mask_image: PIL.Image.Image,
    height: Optional[int] = 512,
    width: Optional[int] = 512,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[np.random.RandomState] = None,
    latents: Optional[np.ndarray] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
            be masked out with `mask_image` and repainted according to `prompt`.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
            to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
            instead of 3, so the expected shape would be `(B, H, W, 1)`.
        height (`int`, *optional*, defaults to 512):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 512):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`np.random.RandomState`, *optional*):
            A np.random.RandomState to make generation deterministic.
        latents (`np.ndarray`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: if the inputs fail `check_inputs`, if `latents` has an unexpected shape, or if the
            latent, mask and masked-image channel counts do not add up to the expected UNet input channels.
    """

    # check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if generator is None:
        generator = np.random

    # set timesteps (done once; nothing below touches the scheduler before the denoising loop)
    self.scheduler.set_timesteps(num_inference_steps)

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    num_channels_latents = NUM_LATENT_CHANNELS
    latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
    latents_dtype = prompt_embeds.dtype
    if latents is None:
        latents = generator.randn(*latents_shape).astype(latents_dtype)
    else:
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

    # prepare mask and masked_image
    mask, masked_image = prepare_mask_and_masked_image(image, mask_image, latents_shape[-2:])
    mask = mask.astype(latents.dtype)
    masked_image = masked_image.astype(latents.dtype)

    masked_image_latents = self.vae_encoder(sample=masked_image)[0]
    masked_image_latents = 0.18215 * masked_image_latents

    # duplicate mask and masked_image_latents for each generation per prompt
    mask = mask.repeat(batch_size * num_images_per_prompt, 0)
    masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0)

    mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask
    masked_image_latents = (
        np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
    )

    num_channels_mask = mask.shape[1]
    num_channels_masked_image = masked_image_latents.shape[1]

    unet_input_channels = NUM_UNET_INPUT_CHANNELS
    if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels:
        raise ValueError(
            "Incorrect configuration settings! The config of `pipeline.unet` expects"
            f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
            f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
            f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
            " `pipeline.unet` or your `mask_image` or `image` input."
        )

    # scale the initial noise by the standard deviation required by the scheduler.
    # A Python float is used on purpose: under NumPy 2 promotion rules an `np.float64` scalar would
    # upcast float32/float16 latents to float64, which is not the dtype the rest of the inputs use.
    latents = latents * float(self.scheduler.init_noise_sigma)

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # look up the dtype the ONNX UNet declares for its `timestep` input (float tensor if undeclared)
    timestep_dtype = next(
        (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
        # concat latents, mask, masked_image_latents in the channel dimension
        latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
        latent_model_input = latent_model_input.cpu().numpy()
        latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)

        # predict the noise residual
        timestep = np.array([t], dtype=timestep_dtype)
        noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
            0
        ]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        scheduler_output = self.scheduler.step(
            torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
        )
        latents = scheduler_output.prev_sample.numpy()

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    latents = 1 / 0.18215 * latents
    # image = self.vae_decoder(latent_sample=latents)[0]
    # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
    image = np.concatenate(
        [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
    )

    image = np.clip(image / 2 + 0.5, 0, 1)
    image = image.transpose((0, 2, 3, 1))

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)
        # safety_checker does not support batched inputs yet
        images, has_nsfw_concept = [], []
        for i in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def preprocess(image):
    """Convert a PIL image into a `[-1, 1]` float32 array with a leading batch axis.

    The image is first resized (LANCZOS) so that both sides are rounded down to a multiple of 32,
    scaled to `[0, 1]`, given a batch axis and moved from `(1, H, W, C)` to `(1, C, H, W)`.
    """
    width, height = (side // 32 * 32 for side in image.size)  # round down to a multiple of 32
    resized = image.resize((width, height), resample=PIL.Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return 2.0 * pixels - 1.0


def preprocess_mask(mask, scale_factor=8):
    """Convert a PIL mask into an inverted float32 array of shape `(1, 4, H', W')`.

    The mask is converted to luminance, its sides are rounded down to a multiple of 32 and then
    divided by `scale_factor` (NEAREST resize). The result is scaled to `[0, 1]`, replicated 4 times
    along a new leading axis, given a batch axis and inverted, so white pixels become 0 (repainted)
    and black pixels become 1 (kept).
    """
    gray = mask.convert("L")
    width, height = (side // 32 * 32 for side in gray.size)  # round down to a multiple of 32
    small = gray.resize((width // scale_factor, height // scale_factor), resample=PIL.Image.NEAREST)
    values = np.array(small).astype(np.float32) / 255.0
    # (H', W') -> (4, H', W') -> (1, 4, H', W'); the axes are already in their final order
    stacked = np.tile(values, (4, 1, 1))[None]
    return 1 - stacked  # repaint white, keep black
def __init__(
    self,
    vae_encoder: OnnxRuntimeModel,
    vae_decoder: OnnxRuntimeModel,
    text_encoder: OnnxRuntimeModel,
    tokenizer: CLIPTokenizer,
    unet: OnnxRuntimeModel,
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    safety_checker: OnnxRuntimeModel,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and patch outdated scheduler configs.

    Args:
        vae_encoder: ONNX model used to encode images into latents.
        vae_decoder: ONNX model used to decode latents back into images.
        text_encoder: ONNX text encoder producing the prompt embeddings.
        tokenizer: Tokenizer paired with `text_encoder`.
        unet: ONNX denoising network.
        scheduler: Scheduler driving the denoising loop. If its config has `steps_offset != 1` or
            `clip_sample is True`, a deprecation is emitted and the config is rewritten in place.
        safety_checker: Optional ONNX safety checker; `None` disables it.
        feature_extractor: Pre-processor for the safety checker inputs; required when
            `safety_checker` is given.
        requires_safety_checker: When `True`, a warning is logged if `safety_checker` is `None`.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # the scheduler config is frozen, so replace it wholesale with a patched copy
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # NOTE: this literal needs the `f` prefix, otherwise `{self.__class__}` is printed verbatim
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae_encoder=vae_encoder,
        vae_decoder=vae_decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: Optional[int],
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]],
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`):
            prompt to be encoded. May be `None` when `prompt_embeds` is supplied.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`):
            The prompt or prompts not to guide the image generation. Only used when
            `do_classifier_free_guidance` is true and `negative_prompt_embeds` is not given.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.

    Returns:
        `np.ndarray`: the prompt embeddings repeated `num_images_per_prompt` times along axis 0. With
        classifier free guidance the (repeated) negative embeddings are concatenated in front of them.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given but have different types.
        ValueError: if a list `negative_prompt` does not match the batch size.
    """
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

        if not np.array_equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            # only comparable when a text prompt was given; with precomputed `prompt_embeds`
            # (`prompt is None`) a textual negative prompt is still valid
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # pad the unconditional tokens to the sequence length of the prompt embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]

    if do_classifier_free_guidance:
        negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def check_inputs( self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def __call__( self, prompt: Union[str, List[str]], image: Union[np.ndarray, PIL.Image.Image] = None, mask_image: Union[np.ndarray, PIL.Image.Image] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[np.random.RandomState] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`nd.ndarray` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. This is the image whose masked region will be inpainted. mask_image (`nd.ndarray` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. 
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # check inputs. 
Raise error if not correct self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if generator is None: generator = np.random # set timesteps self.scheduler.set_timesteps(num_inference_steps) if isinstance(image, PIL.Image.Image): image = preprocess(image) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) # encode the init image into latents and scale the latents init_latents = self.vae_encoder(sample=image)[0] init_latents = 0.18215 * init_latents # Expand init_latents for batch_size and num_images_per_prompt init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) init_latents_orig = init_latents # preprocess mask if not isinstance(mask_image, np.ndarray): mask_image = preprocess_mask(mask_image, 8) mask_image = mask_image.astype(latents_dtype) mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) # check sizes if not mask.shape == init_latents.shape: raise ValueError("The mask and image should be the same size!") # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = 
min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps.numpy()[-init_timestep] timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) # add noise to latents using the timesteps noise = generator.randn(*init_latents.shape).astype(latents_dtype) init_latents = self.scheduler.add_noise( torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) ) init_latents = init_latents.numpy() # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:].numpy() timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ 0 ] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( torch.from_numpy(noise_pred), t, 
torch.from_numpy(latents), **extra_step_kwargs ).prev_sample latents = latents.numpy() init_latents_proper = self.scheduler.add_noise( torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) ) init_latents_proper = init_latents_proper.numpy() latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 image = np.concatenate( [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] ) image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) if self.safety_checker is not None: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) # There will throw an error if use safety_checker batchsize>1 images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] ) images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) else: has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py ================================================ from logging import getLogger from typing import Any, Callable, List, Optional, Union import numpy as np import PIL import torch from ...schedulers import DDPMScheduler from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from 
..pipeline_utils import ImagePipelineOutput from . import StableDiffusionUpscalePipeline logger = getLogger(__name__) NUM_LATENT_CHANNELS = 4 NUM_UNET_INPUT_CHANNELS = 7 ORT_TO_PT_TYPE = { "float16": torch.float16, "float32": torch.float32, } def preprocess(image): if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 32 image = [np.array(i.resize((w, h)))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): def __init__( self, vae: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: Any, unet: OnnxRuntimeModel, low_res_scheduler: DDPMScheduler, scheduler: Any, max_noise_level: int = 350, ): super().__init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, low_res_scheduler=low_res_scheduler, scheduler=scheduler, safety_checker=None, feature_extractor=None, watermarker=None, max_noise_level=max_noise_level, ) def __call__( self, prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], num_inference_steps: int = 75, guidance_scale: float = 9.0, noise_level: int = 20, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, 
torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. image (`np.ndarray` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter will be modulated by `strength`. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. noise_level TODO negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`np.random.RandomState`, *optional*): A np.random.RandomState to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs(prompt, image, noise_level, callback_steps) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_embeddings = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)] # 4. Preprocess image image = preprocess(image) image = image.cpu() # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) noise_level = np.concatenate([noise_level] * image.shape[0]) # 6. Prepare latent variables height, width = image.shape[2:] latents = self.prepare_latents( batch_size * num_images_per_prompt, NUM_LATENT_CHANNELS, height, width, latents_dtype, device, generator, latents, ) # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS: raise ValueError( "Incorrect configuration settings! 
The config of `pipeline.unet` expects" f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +" f" `num_channels_image`: {num_channels_image} " f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) timestep_dtype = next( (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)" ) timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = np.concatenate([latent_model_input, image], axis=1) # timestep to tensor timestep = np.array([t], dtype=timestep_dtype) # predict the noise residual noise_pred = self.unet( sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings, class_labels=noise_level.astype(np.int64), )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step( torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs ).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None 
and i % callback_steps == 0: callback(i, t, latents) # 10. Post-processing image = self.decode_latents(latents.float()) # 11. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) def decode_latents(self, latents): latents = 1 / 0.08333 * latents image = self.vae(latent_sample=latents)[0] image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) return image def _encode_prompt( self, prompt: Union[str, List[str]], device, num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) # no positional arguments to text_encoder prompt_embeds = self.text_encoder( input_ids=text_input_ids.int().to(device), # attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, 
num_images_per_prompt) prompt_embeds = prompt_embeds.reshape(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) # if hasattr(uncond_input, "attention_mask"): # attention_mask = uncond_input.attention_mask.to(device) # else: # attention_mask = None uncond_embeddings = self.text_encoder( input_ids=uncond_input.input_ids.int().to(device), # attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] if do_classifier_free_guidance: seq_len = uncond_embeddings.shape[1] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([uncond_embeddings, prompt_embeds]) return prompt_embeds ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale the guided noise prediction so its per-sample spread matches the text-conditioned prediction.

    Follows Section 3.4 of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: noise prediction after classifier-free guidance has been applied.
        noise_pred_text: the text-conditioned noise prediction.
        guidance_rescale: blend factor between the rescaled prediction (weight `guidance_rescale`) and the
            untouched `noise_cfg` (weight `1 - guidance_rescale`).

    Returns:
        The blended noise prediction.
    """
    # Standard deviations are reduced over every axis except the leading one and kept broadcastable.
    text_axes = list(range(1, noise_pred_text.ndim))
    cfg_axes = list(range(1, noise_cfg.ndim))
    text_std = noise_pred_text.std(dim=text_axes, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_axes, keepdim=True)

    # Bring the guided prediction back to the spread of the text prediction (fixes overexposure).
    rescaled = noise_cfg * (text_std / cfg_std)

    # Mix with the original guided result to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Validate and register the pipeline components.

    Outdated configurations are patched in place (with a deprecation notice): a scheduler whose
    `steps_offset` is not 1 gets `steps_offset=1`, a scheduler with `clip_sample=True` gets
    `clip_sample=False`, and a UNet saved by a diffusers version older than 0.9.0 with `sample_size < 64`
    gets `sample_size=64`.

    Args:
        vae: autoencoder; its `block_out_channels` determine `self.vae_scale_factor`.
        text_encoder: text encoder registered on the pipeline.
        tokenizer: tokenizer registered on the pipeline.
        unet: denoising UNet registered on the pipeline.
        scheduler: scheduler registered on the pipeline.
        safety_checker: optional safety checker; `None` disables it.
        feature_extractor: required whenever `safety_checker` is not `None`.
        requires_safety_checker: stored in the pipeline config; when `True` and `safety_checker` is `None`
            a warning is logged.

    Raises:
        ValueError: if a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        # The config is frozen, so a patched copy replaces the internal dict.
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # Ratio between pixel-space and latent-space resolution, derived from the VAE depth.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Hand the `unet`, `text_encoder`, `vae` and (when present) the safety checker to accelerate's
    `cpu_offload`, targeting the device `cuda:{gpu_id}`.

    If the pipeline is not currently on CPU it is first moved there and the CUDA cache is emptied. Compared
    with `enable_model_cpu_offload`, offloading happens on a submodule basis: memory savings are higher but
    performance is lower.

    Raises:
        ImportError: if accelerate is unavailable or older than v0.14.0.
    """
    accelerate_ready = is_accelerate_available() and is_accelerate_version(">=", "0.14.0")
    if not accelerate_ready:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    already_on_cpu = self.device.type == "cpu"
    if not already_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        # otherwise we don't see the memory savings (but they probably exist)
        torch.cuda.empty_cache()

    for submodel in (self.unet, self.text_encoder, self.vae):
        cpu_offload(submodel, execution_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=execution_device, offload_buffers=True)
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True` and
            `negative_prompt_embeds` is `None`. A single string is applied to every item of the batch.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        The prompt embeddings repeated `num_images_per_prompt` times per prompt. With classifier free guidance
        the negative embeddings are concatenated in front of them along the batch dimension.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given but have different types.
        ValueError: if `negative_prompt` is a list whose length differs from the batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to report what was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # Repeat the single negative prompt for every batch item. With a `str` prompt the batch size is 1,
            # so this only matters when batched `prompt_embeds` are passed without `prompt`; a one-element
            # list would then fail in the reshape to `batch_size * num_images_per_prompt` below.
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generate image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Passing `"latent"`
            skips VAE decoding and the safety checker and returns the latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its optional `"scale"` entry is also forwarded to `_encode_prompt` as `lora_scale`.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR. Only applied when
            classifier free guidance is active and the value is greater than 0.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                # The batch is ordered [unconditional, text], matching `_encode_prompt`.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# See the License for the specific language governing permissions and # limitations under the License. import inspect import math import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionAttendAndExcitePipeline >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... ).to("cuda") >>> prompt = "a cat and a frog" >>> # use get_indices function to find out indices of the tokens you want to alter >>> pipe.get_indices(prompt) {0: '<|startoftext|>', 1: 'a', 2: 'cat', 3: 'and', 4: 'a', 5: 'frog', 6: '<|endoftext|>'} >>> token_indices = [2, 5] >>> seed = 6141 >>> generator = torch.Generator("cuda").manual_seed(seed) >>> images = pipe( ... prompt=prompt, ... token_indices=token_indices, ... guidance_scale=7.5, ... generator=generator, ... num_inference_steps=50, ... max_iter_to_alter=25, ... 
class AttentionStore:
    """Accumulates the cross-attention maps produced while the UNet runs.

    Attention processors call the instance once per attention layer. Maps that qualify are appended to
    `step_store`; once `num_att_layers` calls have been counted, the collected maps are published as
    `attention_store` and a fresh `step_store` is started.
    """

    def __init__(self, attn_res):
        """
        Initialize an empty AttentionStore.

        :param attn_res: pair of ints. Only maps whose second dimension equals the product of the pair are
            kept, and kept maps are reshaped to `(-1, attn_res[0], attn_res[1], last_dim)` on aggregation.
        """
        self.attn_res = attn_res
        # -1 never matches the running counter, so nothing is published until the owner assigns the
        # real number of attention layers.
        self.num_att_layers = -1
        self.curr_step_index = 0
        self.cur_att_layer = 0
        self.attention_store = {}
        self.step_store = self.get_empty_store()

    @staticmethod
    def get_empty_store():
        """Return a new mapping holding one empty list per UNet section."""
        return {section: [] for section in ("down", "mid", "up")}

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Record `attn` if it is a cross-attention map of the tracked size, then advance the layer counter."""
        is_candidate = self.cur_att_layer >= 0 and is_cross
        if is_candidate and attn.shape[1] == np.prod(self.attn_res):
            self.step_store[place_in_unet].append(attn)

        self.cur_att_layer += 1
        if self.cur_att_layer != self.num_att_layers:
            return

        # Every layer has reported for this pass: publish the maps and start counting again.
        self.cur_att_layer = 0
        self.between_steps()

    def between_steps(self):
        """Publish the maps gathered so far and begin a new, empty collection."""
        self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the most recently published maps (no averaging is performed here)."""
        return self.attention_store

    def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
        """Aggregates the attention across the different layers and heads at the specified resolution."""
        published = self.get_average_attention()
        reshaped = [
            item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
            for location in from_where
            for item in published[location]
        ]
        stacked = torch.cat(reshaped, dim=0)
        return stacked.sum(0) / stacked.shape[0]

    def reset(self):
        """Drop all collected maps and restart the layer counter."""
        self.attention_store = {}
        self.step_store = self.get_empty_store()
        self.cur_att_layer = 0


class AttendExciteAttnProcessor:
    """Attention processor that forwards attention probabilities to an `AttentionStore` while gradients are on."""

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.place_in_unet = place_in_unet
        self.attnstore = attnstore

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Compute attention for `hidden_states`, handing the probabilities to the store when they require grad."""
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        # Without a separate context this is self-attention: keys and values come from the input itself.
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key, attention_mask)

        # only need to store attention maps during the Attend and Excite process
        if probs.requires_grad:
            self.attnstore(probs, is_cross, self.place_in_unet)

        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # to_out[0] is the linear projection, to_out[1] the dropout
        projected = attn.to_out[0](attended)
        return attn.to_out[1](projected)
# Components that may be passed as `None` when constructing the pipeline.
_optional_components = ["safety_checker", "feature_extractor"]

def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and derive the VAE scale factor and image processor.

    Logs a warning when `safety_checker` is `None` while `requires_safety_checker` is true.

    Raises:
        ValueError: if a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is emitted literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of two per VAE block transition.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding.

    When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
    steps. This is useful to save some memory and allow larger batch sizes.
    """
    self.vae.enable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
    computing decoding in one step.
    """
    self.vae.disable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
    text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
    `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
    Note that offloading happens on a submodule basis. Memory savings are higher than with
    `enable_model_cpu_offload`, but performance is lower.

    Args:
        gpu_id: index used to build the `cuda:{gpu_id}` execution device.

    Raises:
        ImportError: if accelerate is unavailable or older than 0.14.0.
    """
    if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
        from accelerate import cpu_offload
    else:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
        cpu_offload(cpu_offloaded_model, device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.
    """
    if not hasattr(self.unet, "_hf_hook"):
        return self.device
    # Use the first submodule hook that carries a concrete execution device.
    for module in self.unet.modules():
        if (
            hasattr(module, "_hf_hook")
            and hasattr(module._hf_hook, "execution_device")
            and module._hf_hook.execution_device is not None
        ):
            return torch.device(module._hf_hook.execution_device)
    return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        The prompt embeddings repeated `num_images_per_prompt` times. With classifier free guidance the negative
        embeddings are concatenated in front of them along the first dimension.

    Raises:
        TypeError: if `negative_prompt` and `prompt` are both given but differ in type.
        ValueError: if a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    # NOTE(review): the class statement for this pipeline lists only `DiffusionPipeline` and
    # `TextualInversionLoaderMixin` as bases, so this branch presumably only runs for subclasses that also
    # mix in `LoraLoaderMixin` - confirm.
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Second, untruncated tokenization: used only to detect and report what truncation removed.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        # The first element of the text encoder output is used as the embedding.
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional input to the same sequence length as the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Run the optional safety checker on `image`.

    Returns `(image, has_nsfw_concept)`. When no safety checker is configured the image is returned unchanged
    and `has_nsfw_concept` is `None`.
    """
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        # The feature extractor is fed PIL images, so convert from whichever form `image` arrives in.
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """Decode `latents` with the VAE into a float32 NumPy image batch (deprecated; emits a `FutureWarning`)."""
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    # Undo the scaling factor from the VAE config before decoding.
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents, return_dict=False)[0]
    # Shift and clamp the decoded values into [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the keyword arguments that `self.scheduler.step` is able to accept.

    Returns a dict containing `eta` and/or `generator`, each only when the scheduler's `step` signature
    declares a parameter of that name.
    """
    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    # Inspect the signature once and reuse it for both lookups.
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta

    # check if the scheduler accepts generator
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs

def check_inputs(
    self,
    prompt,
    indices,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the arguments of a pipeline call.

    Raises:
        ValueError: if `height`/`width` are not divisible by 8, `callback_steps` is not a positive int,
            the prompt / embedding arguments are inconsistent, or the batch size implied by `indices`
            differs from the prompt batch size.
        TypeError: if `indices` is not a non-empty list of ints or a non-empty list of non-empty lists of
            ints (only the first element is type-checked).
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    # Guard the emptiness of each level before indexing it, so `[]` and `[[]]` reach the TypeError
    # below instead of failing with an IndexError.
    indices_is_nonempty_list = isinstance(indices, list) and len(indices) > 0
    indices_is_list_ints = indices_is_nonempty_list and isinstance(indices[0], int)
    indices_is_list_list_ints = (
        indices_is_nonempty_list
        and isinstance(indices[0], list)
        and len(indices[0]) > 0
        and isinstance(indices[0][0], int)
    )

    if not indices_is_list_ints and not indices_is_list_list_ints:
        raise TypeError("`indices` must be a list of ints or a list of a list of ints")

    # A flat list addresses a single prompt; a list of lists carries one entry per prompt.
    if indices_is_list_ints:
        indices_batch_size = 1
    elif indices_is_list_list_ints:
        indices_batch_size = len(indices)

    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    if indices_batch_size != prompt_batch_size:
        raise ValueError(
            f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
        )

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return initial latents scaled by the scheduler's `init_noise_sigma`.

    Fresh noise is drawn with `randn_tensor` when `latents` is `None`; otherwise the given latents are
    moved to `device`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
@staticmethod
def _compute_max_attention_per_index(
    attention_maps: torch.Tensor,
    indices: List[int],
) -> List[torch.Tensor]:
    """Computes the maximum attention value for each of the tokens we wish to alter.

    The first and last entries of the last dimension are dropped, the rest is scaled by 100 and
    soft-maxed over that dimension, and each selected token's map is smoothed before its maximum is taken.
    `attention_maps` itself is left unmodified.
    """
    # Scale out of place: an in-place `*=` on the slice would write through the view into the
    # caller's `attention_maps`.
    attention_for_text = attention_maps[:, :, 1:-1] * 100
    attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)

    # Shift indices since we removed the first token
    indices = [index - 1 for index in indices]

    # Built once and reused for every token.
    # NOTE(review): assumes `GaussianSmoothing` (defined elsewhere in this file) keeps no per-call state - confirm.
    smoothing = GaussianSmoothing().to(attention_maps.device)

    # Extract the maximum values
    max_indices_list = []
    for i in indices:
        image = attention_for_text[:, :, i]
        # Add two leading singleton dims and reflect-pad by one on each side before smoothing.
        padded = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
        image = smoothing(padded).squeeze(0).squeeze(0)
        max_indices_list.append(image.max())
    return max_indices_list

def _aggregate_and_get_max_attention_per_token(
    self,
    indices: List[int],
):
    """Aggregates the attention for each token and computes the max activation value for each token to alter."""
    attention_maps = self.attention_store.aggregate_attention(
        from_where=("up", "down", "mid"),
    )
    max_attention_per_index = self._compute_max_attention_per_index(
        attention_maps=attention_maps,
        indices=indices,
    )
    return max_attention_per_index

@staticmethod
def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torch.Tensor:
    """Computes the attend-and-excite loss using the maximum attention value for each token.

    Each token contributes `max(0, 1 - max_attention)`; the largest contribution is returned. Python's
    builtin `max` is used on purpose, so the result is the plain int `0` when every token's maximum is at
    least 1.
    """
    losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index]
    loss = max(losses)
    return loss

@staticmethod
def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
    """Update the latent according to the computed loss.

    Takes one gradient step of size `step_size` on `latents`; the autograd graph is retained.
    """
    grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
    latents = latents - step_size * grad_cond
    return latents
def _perform_iterative_refinement_step(
    self,
    latents: torch.Tensor,
    indices: List[int],
    loss: torch.Tensor,
    threshold: float,
    text_embeddings: torch.Tensor,
    step_size: float,
    t: int,
    max_refinement_steps: int = 20,
):
    """
    Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent
    code according to our loss objective until the given threshold is reached for all tokens.

    Returns:
        `(loss, latents, max_attention_per_index)` from a final forward pass on the refined latents.
    """
    iteration = 0
    target_loss = max(0, 1.0 - threshold)
    while loss > target_loss:
        iteration += 1

        latents = latents.clone().detach().requires_grad_(True)
        # The UNet output is discarded: the pass is run so the attention processors refill the store.
        self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
        self.unet.zero_grad()

        # Get max activation value for each subject token
        max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
            indices=indices,
        )

        loss = self._compute_loss(max_attention_per_index)

        if loss != 0:
            latents = self._update_latent(latents, loss, step_size)

        logger.info(f"\t Try {iteration}. loss: {loss}")

        if iteration >= max_refinement_steps:
            logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ")
            break

    # Run one more time but don't compute gradients and update the latents.
    # We just need to compute the new loss - the grad update will occur below
    latents = latents.clone().detach().requires_grad_(True)
    _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
    self.unet.zero_grad()

    # Get max activation value for each subject token
    max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
        indices=indices,
    )
    loss = self._compute_loss(max_attention_per_index)
    logger.info(f"\t Finished with loss of: {loss}")
    return loss, latents, max_attention_per_index

def register_attention_control(self):
    """Install an `AttendExciteAttnProcessor` on the UNet attention layers in the mid, up and down blocks.

    Also records the number of installed processors on the attention store as `num_att_layers`.
    """
    attn_procs = {}
    # Counts every processor that is replaced; the loop does not tell self- from cross-attention.
    cross_att_count = 0
    for name in self.unet.attn_processors:
        if name.startswith("mid_block"):
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)

    self.unet.set_attn_processor(attn_procs)
    self.attention_store.num_att_layers = cross_att_count

def get_indices(self, prompt: str) -> Dict[int, str]:
    """Utility function to list the indices of the tokens you wish to alter.

    Returns a dict mapping each token position in the tokenized `prompt` to its token string.
    """
    ids = self.tokenizer(prompt).input_ids
    return dict(enumerate(self.tokenizer.convert_ids_to_tokens(ids)))
True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, max_iter_to_alter: int = 25, thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8}, scale_factor: int = 20, attn_res: Optional[Tuple[int]] = (16, 16), ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. token_indices (`List[int]`): The token indices to alter with attend-and-excite. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step.
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). max_iter_to_alter (`int`, *optional*, defaults to `25`): Number of denoising steps to apply attend-and-excite. The first denoising steps are where the attend-and-excite is applied. I.e. if `max_iter_to_alter` is 25 and there are a total of `30` denoising steps, the first 25 denoising steps will apply attend-and-excite and the last 5 will not apply attend-and-excite. thresholds (`dict`, *optional*, defaults to `{0: 0.05, 10: 0.5, 20: 0.8}`): Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in. scale_factor (`int`, *optional*, default to 20): Scale factor that controls the step size of each Attend and Excite update. attn_res (`tuple`, *optional*, default computed from width and height): The 2D resolution of the semantic attention map. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. :type attention_store: object """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, token_indices, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) if attn_res is None: attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32)) self.attention_store = AttentionStore(attn_res) self.register_attention_control() # default config for step size from original repo scale_range = np.linspace(1.0, 0.5, len(self.scheduler.timesteps)) step_size = scale_factor * np.sqrt(scale_range) text_embeddings = ( prompt_embeds[batch_size * num_images_per_prompt :] if do_classifier_free_guidance else prompt_embeds ) if isinstance(token_indices[0], int): token_indices = [token_indices] indices = [] for ind in token_indices: indices = indices + [ind] * num_images_per_prompt # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Attend and excite process with torch.enable_grad(): latents = latents.clone().detach().requires_grad_(True) updated_latents = [] for latent, index, text_embedding in zip(latents, indices, text_embeddings): # Forward pass of denoising with text conditioning latent = latent.unsqueeze(0) text_embedding = text_embedding.unsqueeze(0) self.unet( latent, t, encoder_hidden_states=text_embedding, cross_attention_kwargs=cross_attention_kwargs, ).sample self.unet.zero_grad() # Get max activation value for each subject token max_attention_per_index = self._aggregate_and_get_max_attention_per_token( indices=index, ) loss = self._compute_loss(max_attention_per_index=max_attention_per_index) # If this is an iterative refinement step, verify we have reached the desired threshold for all if i in thresholds.keys() and loss > 1.0 - thresholds[i]: loss, latent, max_attention_per_index = self._perform_iterative_refinement_step( latents=latent, indices=index, loss=loss, threshold=thresholds[i], text_embeddings=text_embedding, step_size=step_size[i], t=t, ) # Perform gradient update if i < max_iter_to_alter: if loss != 0: latent = self._update_latent( latents=latent, loss=loss, step_size=step_size[i], ) logger.info(f"Iteration {i} | Loss: {loss:0.4f}") updated_latents.append(latent) latents = torch.cat(updated_latents, dim=0) # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, 
class GaussianSmoothing(torch.nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the
    input using a depthwise convolution.

    Arguments:
        channels (`int`, *optional*, defaults to 1):
            Number of channels of the input tensors. Output will have this number of channels as well.
        kernel_size (`int` or sequence of `int`, *optional*, defaults to 3):
            Size of the gaussian kernel. A scalar is repeated for each of the `dim` axes.
        sigma (`float`, `int` or sequence, *optional*, defaults to 0.5):
            Standard deviation parameter of the gaussian kernel. A scalar is repeated for each of the `dim` axes.
        dim (`int`, *optional*, defaults to 2):
            The number of dimensions of the data. Default value is 2 (spatial).

    Raises:
        RuntimeError: If `dim` is not 1, 2 or 3.
    """

    # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2
    def __init__(
        self,
        channels: int = 1,
        kernel_size: int = 3,
        sigma: float = 0.5,
        dim: int = 2,
    ):
        super().__init__()

        # Validate `dim` before any kernel construction so an unsupported value always fails with the same error.
        conv = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}.get(dim)
        if conv is None:
            raise RuntimeError(f"Only 1, 2 and 3 dimensions are supported. Received {dim}.")

        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * dim
        # Accept any real scalar: an integer sigma (e.g. `sigma=1`) must be broadcast exactly like a float one.
        if isinstance(sigma, (int, float)):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # NOTE: the exponent divides by (2 * std) before squaring, i.e. exp(-(x - mean)^2 / (4 * std^2)),
            # which is a gaussian whose effective standard deviation is sqrt(2) * std. Kept as-is on purpose so
            # that existing results do not change.
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight: one (1, *kernel_size) filter per channel.
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        # Registered as a buffer (not a parameter): it follows `.to()` / `state_dict()` but is never trained.
        self.register_buffer("weight", kernel)
        self.groups = channels
        self.conv = conv

    def forward(self, input):
        """
        Apply gaussian filter to input.

        Arguments:
            input (`torch.Tensor`): Input to apply gaussian filter on.

        Returns:
            filtered (`torch.Tensor`): Filtered output. No padding is passed to the convolution, so callers that
            need a same-sized output must pad `input` beforehand.
        """
        # The kernel is cast to the input dtype on every call so half-precision inputs work with the fp32 buffer.
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE: This file is deprecated and will be removed in a future version. # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works from ...utils import deprecate from ..controlnet.multicontrolnet import MultiControlNetModel # noqa: F401 from ..controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline # noqa: F401 deprecate( "stable diffusion controlnet", "0.22.0", "Importing `StableDiffusionControlNetPipeline` or `MultiControlNetModel` from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import StableDiffusionControlNetPipeline` instead.", standard_warn=False, stacklevel=3, ) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    """
    Deprecated helper that turns PIL images or tensors into a single batched tensor.

    A `FutureWarning` is always emitted first. Then:
      - a `torch.Tensor` is returned untouched;
      - a PIL image (or a list of them) is resized so that both sides are multiples of 8 (the size is taken from
        the first image), stacked, divided by 255, moved from channels-last to channels-first and mapped with
        `2 * x - 1` before being wrapped in a tensor;
      - a list of tensors is concatenated along dimension 0;
      - any other list is handed back as received.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]

    head = image[0]
    if isinstance(head, PIL.Image.Image):
        width, height = head.size
        # round both sides down to an integer multiple of 8
        width -= width % 8
        height -= height % 8

        frames = [
            np.array(picture.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for picture in image
        ]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = 2.0 * batch.transpose(0, 3, 1, 2) - 1.0
        return torch.from_numpy(batch)

    if isinstance(head, torch.Tensor):
        return torch.cat(image, dim=0)

    return image
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, depth_estimator: DPTForDepthEstimation, feature_extractor: DPTFeatureExtractor, ): super().__init__() is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. 
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the pipeline's sub-models to CPU using accelerate.

    Each of `unet`, `text_encoder`, `vae` and `depth_estimator` that is not `None` is passed to accelerate's
    `cpu_offload` together with the device `cuda:<gpu_id>`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index used to build the `cuda:<gpu_id>` device handed to `cpu_offload`.

    Raises:
        ImportError: If accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")

    # Imported lazily so that the pipeline stays importable without accelerate installed.
    from accelerate import cpu_offload

    target = torch.device(f"cuda:{gpu_id}")

    components = (self.unet, self.text_encoder, self.vae, self.depth_estimator)
    for component in components:
        if component is None:
            continue
        cpu_offload(component, target)
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Deprecated: decode latents with the VAE and return the result as a float32 numpy array.

    The latents are multiplied by `1 / self.vae.config.scaling_factor`, decoded with `self.vae.decode`, mapped
    with `x / 2 + 0.5`, clamped to `[0, 1]`, moved to CPU and permuted with `(0, 2, 3, 1)` (channels last).
    A `FutureWarning` is emitted on every call.
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )

    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents, return_dict=False)[0]

    shifted = decoded / 2 + 0.5
    clipped = shifted.clamp(0, 1)

    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    channels_last = clipped.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """
    Validate the user-facing call arguments; returns `None` when everything is acceptable.

    Checks run in this order and the first failure wins:
      1. `strength` must not be below 0 or above 1;
      2. `callback_steps` must be a positive `int`;
      3. exactly one of `prompt` / `prompt_embeds` must be given, and `prompt` must be a `str` or `list`;
      4. `negative_prompt` and `negative_prompt_embeds` are mutually exclusive;
      5. when both embeddings are given, their `.shape` must match.

    Raises:
        ValueError: On the first violated rule.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    steps_are_valid = callback_steps is not None and isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None

    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
) elif isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_guidance, dtype, device):
    """
    Build the depth conditioning tensor that accompanies the image latents.

    Args:
        image: A PIL image, or an iterable of PIL images / numpy arrays / tensors. Only the first element is
            inspected to determine the target height and width.
        depth_map (`torch.Tensor`, *optional*): A user-supplied depth map. When `None`, one is predicted from
            `image` with `self.feature_extractor` and `self.depth_estimator`.
        batch_size (`int`): Number of depth maps wanted before classifier-free-guidance duplication.
        do_classifier_free_guidance (`bool`): When true, the result is concatenated with itself along dim 0.
        dtype: dtype of the returned tensor (also used for autocast on CUDA).
        device (`torch.device`): Device the depth map is moved to / computed on.

    Returns:
        `torch.Tensor`: The depth map resized to `(height // vae_scale_factor, width // vae_scale_factor)`,
        rescaled per sample with `2 * (d - min) / (max - min) - 1`, and cast to `dtype`.
    """
    if isinstance(image, PIL.Image.Image):
        image = [image]
    else:
        image = list(image)

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        # PIL reports (width, height).
        width, height = first.size
    elif isinstance(first, np.ndarray):
        # Numpy images are channels-last (..., H, W, C), so the two axes before the last are (height, width).
        # The previous code unpacked them as (width, height), which transposed the depth map size for
        # non-square inputs. Slicing from the end also tolerates a leading batch axis of size 1.
        height, width = first.shape[-3:-1]
    else:
        # Tensors are channels-first: the last two axes are (height, width).
        height, width = first.shape[-2:]

    if depth_map is None:
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device=device)
        # The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
        # So we use `torch.autocast` here for half precision inference.
        context_manager = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
        with context_manager:
            depth_map = self.depth_estimator(pixel_values).predicted_depth
    else:
        depth_map = depth_map.to(device=device, dtype=dtype)

    # Add a channel axis, then resize to the latent resolution.
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
        mode="bicubic",
        align_corners=False,
    )

    # Per-sample min/max rescaling.
    # NOTE(review): a perfectly constant depth map makes (max - min) zero and yields NaNs here; left unchanged
    # because the intended value for that case is not defined anywhere in this file.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
    depth_map = depth_map.to(dtype)

    # duplicate the depth map for each generation per prompt, using mps friendly method
    if depth_map.shape[0] < batch_size:
        repeat_by = batch_size // depth_map.shape[0]
        depth_map = depth_map.repeat(repeat_by, 1, 1, 1)

    depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
    return depth_map
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[
        torch.FloatTensor,
        PIL.Image.Image,
        np.ndarray,
        List[torch.FloatTensor],
        List[PIL.Image.Image],
        List[np.ndarray],
    ] = None,
    depth_map: Optional[torch.FloatTensor] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process. Can accept image latents as `image` only if `depth_map` is not `None`. Must not be `None`;
            a `ValueError` is raised otherwise.
        depth_map (`torch.FloatTensor`, *optional*):
            Depth prediction that will be used as additional conditioning for the image generation process. If
            not defined, the depth is predicted automatically via `self.depth_estimator`.
        strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
            will be used as a starting point, adding more noise to it the larger the `strength`. The number of
            denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
            be maximum and the denoising process will run for the full number of iterations specified in
            `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. This parameter will be modulated by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. When set to
            `"latent"` the VAE decoding step is skipped and the latents are handed to the image post-processor
            as they are.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return an `ImagePipelineOutput` instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its optional `"scale"` entry is also forwarded to `_encode_prompt` as the LoRA scale.

    Examples:

    ```py
    >>> import torch
    >>> import requests
    >>> from PIL import Image

    >>> from diffusers import StableDiffusionDepth2ImgPipeline

    >>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    ...     "stabilityai/stable-diffusion-2-depth",
    ...     torch_dtype=torch.float16,
    ... )
    >>> pipe.to("cuda")


    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> init_image = Image.open(requests.get(url, stream=True).raw)
    >>> prompt = "two tigers"
    >>> n_prompt = "bad, deformed, ugly, bad anatomy"
    >>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
    ```

    Returns:
        `ImagePipelineOutput` or `tuple`:
        An `ImagePipelineOutput` whose `images` field holds the post-processed result if `return_dict` is True,
        otherwise a one-element tuple `(image,)`. This method does not run a safety checker.
    """
    # 1. Check inputs
    self.check_inputs(
        prompt,
        strength,
        callback_steps,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare depth mask
    # (must happen before step 5: `prepare_depth_map` receives the image as passed by the caller)
    depth_mask = self.prepare_depth_map(
        image,
        depth_map,
        batch_size * num_images_per_prompt,
        do_classifier_free_guidance,
        prompt_embeds.dtype,
        device,
    )

    # 5. Preprocess image
    image = self.image_processor.preprocess(image)

    # 6. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
    # the first remaining timestep, repeated once per generated sample, sets the initial noise level
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 7. Prepare latent variables
    latents = self.prepare_latents(
        image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )

    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 9. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            # the depth map is appended along the channel dimension as extra conditioning
            latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 10. Decode the latents unless the caller asked for them directly
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
@dataclass
class DiffEditInversionPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion pipelines, carrying inverted latents together with images.

    Args:
        latents (`torch.FloatTensor`)
            inverted latents tensor
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps,
            batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the
            diffusion pipeline.
    """

    # inverted latents tensor
    latents: torch.FloatTensor
    # denoised images, either as a list of PIL images or as a single numpy array
    images: Union[List[PIL.Image.Image], np.ndarray]
# Usage examples that are substituted into method docstrings.
# The examples previously configured the schedulers on an undefined name `pipeline` and never imported
# `DDIMScheduler` / `DDIMInverseScheduler`, so they could not be run as written; both now use `pipe`.
EXAMPLE_DOC_STRING = """
        ```py
        >>> import PIL
        >>> import requests
        >>> import torch
        >>> from io import BytesIO

        >>> from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"

        >>> init_image = download_image(img_url).resize((768, 768))

        >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
        >>> pipe.enable_model_cpu_offload()

        >>> mask_prompt = "A bowl of fruits"
        >>> prompt = "A bowl of pears"

        >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt)
        >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents
        >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0]
        ```
"""

EXAMPLE_INVERT_DOC_STRING = """
        ```py
        >>> import PIL
        >>> import requests
        >>> import torch
        >>> from io import BytesIO

        >>> from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionDiffEditPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"

        >>> init_image = download_image(img_url).resize((768, 768))

        >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A bowl of fruits"

        >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents
        ```
"""
def auto_corr_loss(hidden_states, generator=None):
    """
    Accumulate a squared auto-correlation penalty over every (sample, channel) plane.

    Each plane is multiplied with a copy of itself rolled by a random shift along dim 2 and along dim 3; the
    squared means of both products are added to the total. The plane is then average-pooled by a factor of 2 and
    the procedure repeats, stopping after the level whose size along dim 2 is at most 8.

    Args:
        hidden_states: 4-D tensor indexed as `[sample, channel, :, :]`.
        generator: optional `torch.Generator` used to draw the roll amounts.

    Returns:
        The accumulated loss (the float `0.0` if there was nothing to iterate over, otherwise a tensor).
    """
    total = 0.0
    for sample_idx in range(hidden_states.shape[0]):
        for channel_idx in range(hidden_states.shape[1]):
            plane = hidden_states[sample_idx : sample_idx + 1, channel_idx : channel_idx + 1, :, :]
            more_levels = True
            while more_levels:
                # one random shift per pyramid level, shared by both spatial axes
                shift = torch.randint(plane.shape[2] // 2, (1,), generator=generator).item()
                for axis in (2, 3):
                    total += (plane * torch.roll(plane, shifts=shift, dims=axis)).mean() ** 2

                more_levels = plane.shape[2] > 8
                if more_levels:
                    plane = torch.nn.functional.avg_pool2d(plane, kernel_size=2)
    return total


def kl_divergence(hidden_states):
    """
    Return `var + mean ** 2 - 1 - log(var + 1e-7)` computed over all elements of `hidden_states`.
    """
    variance = hidden_states.var()
    squared_mean = hidden_states.mean() ** 2
    return variance + squared_mean - 1 - torch.log(variance + 1e-7)
def preprocess(image):
    """
    Deprecated image pre-processing helper; always emits a `FutureWarning`.

    A tensor is returned untouched. PIL input (a single image or a list) is resized down to the nearest multiple
    of 8 in each dimension, stacked, scaled to [-1, 1] and returned as a channels-first tensor. A list of tensors
    is concatenated along dim 0.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image

    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        w, h = first.size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        frames = [np.array(img.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for img in image]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        image = torch.from_numpy(batch)
    elif isinstance(first, torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
def preprocess_mask(mask, batch_size: int = 1):
    """
    Convert a mask given in any supported form into a binary tensor with a batch and a single channel dimension.

    Args:
        mask: a `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or a list of one of those.
        batch_size (`int`, defaults to 1): batch size the mask must match; a single mask is repeated to fill it.

    Returns:
        `torch.Tensor` whose dim 1 has size 1 and whose values are 0 (input < 0.5) or 1 (input >= 0.5).
        The caller's input is never modified.

    Raises:
        ValueError: if the mask batch cannot be broadcast to `batch_size`, if it has more than one channel, or
            if its values fall outside [0, 1].
    """
    if not isinstance(mask, torch.Tensor):
        # preprocess mask
        if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray):
            mask = [mask]

        if isinstance(mask, list):
            if isinstance(mask[0], PIL.Image.Image):
                mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask]
            if isinstance(mask[0], np.ndarray):
                mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0)
                mask = torch.from_numpy(mask)
            elif isinstance(mask[0], torch.Tensor):
                mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0)

    # Batch and add channel dim for single mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    # Batch single mask or add channel dim
    if mask.ndim == 3:
        # Single batched mask, no channel dim or single mask not batched but channel dim
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)

        # Batched masks no channel dim
        else:
            mask = mask.unsqueeze(1)

    # Check mask shape
    if batch_size > 1:
        if mask.shape[0] == 1:
            mask = torch.cat([mask] * batch_size)
        elif mask.shape[0] > 1 and mask.shape[0] != batch_size:
            raise ValueError(
                f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} "
                f"inferred by prompt inputs"
            )

    if mask.shape[1] != 1:
        raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels")

    # Check mask is in [0, 1]
    if mask.min() < 0 or mask.max() > 1:
        raise ValueError("`mask_image` should be in [0, 1] range")

    # Binarize mask.
    # At this point `mask` may still be the caller's tensor (or an `unsqueeze` view of it), so work on a copy:
    # the in-place writes below must not leak back into the argument.
    mask = mask.clone()
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    return mask
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    inverse_scheduler: DDIMInverseScheduler,
    requires_safety_checker: bool = True,
):
    r"""
    Register the pipeline components after patching outdated scheduler / unet configurations.

    Outdated settings (`steps_offset != 1`, `skip_prk_steps is False`, or an old unet with `sample_size < 64`)
    trigger a deprecation notice and are overwritten in the component's internal config.

    Raises:
        ValueError: if a `safety_checker` is supplied without a `feature_extractor`.
    """
    super().__init__()

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration"
            " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
            " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
            " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
            " Hub, it would be very nice if you could open a Pull request for the"
            " `scheduler/scheduler_config.json` file"
        )
        deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["skip_prk_steps"] = True
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # NOTE: this literal needs the `f` prefix, otherwise `{self.__class__}` is printed verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Only unets saved by a diffusers release older than 0.9.0 get their `sample_size` patched.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        inverse_scheduler=inverse_scheduler,
    )
    # spatial downsampling factor of the VAE: one factor of 2 per block transition
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
    r"""
    Enable sliced VAE decoding.

    When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
    steps. This is useful to save some memory and allow larger batch sizes.

    Delegates to `self.vae.enable_slicing()`.
    """
    self.vae.enable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
    r"""
    Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
    computing decoding in one step.

    Delegates to `self.vae.disable_slicing()`.
    """
    self.vae.disable_slicing()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
    r"""
    Enable tiled VAE decoding.

    When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
    several steps. This is useful to save a large amount of memory and to allow the processing of larger images.

    Delegates to `self.vae.enable_tiling()`.
    """
    self.vae.enable_tiling()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
    r"""
    Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
    computing decoding in one step.

    Delegates to `self.vae.disable_tiling()`.
    """
    self.vae.disable_tiling()
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload the unet, text encoder, VAE and (if present) safety checker with accelerate's `cpu_offload`.

    The pipeline is first moved to the CPU if it is not already there, then each model is registered for
    offloading with `cuda:{gpu_id}` as execution device. Compared with `enable_model_cpu_offload` this saves
    more memory at the price of lower performance.

    Raises:
        ImportError: if accelerate is unavailable or older than 0.14.0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    on_cpu = self.device.type == "cpu"
    if not on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for component in (self.unet, self.text_encoder, self.vae):
        cpu_offload(component, target_device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=target_device, offload_buffers=True)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole models to the CPU with accelerate's `cpu_offload_with_hook`.

    The text encoder, unet, VAE and (if present) safety checker are chained in that order, each hook receiving
    the previous one as `prev_module_hook`. The last hook is stored in `self.final_offload_hook` so the final
    model can be offloaded manually. Compared with `enable_sequential_cpu_offload` the memory savings are lower
    but performance is much better.

    Raises:
        ImportError: if accelerate is unavailable or older than 0.17.0.dev0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")):
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target_device = torch.device(f"cuda:{gpu_id}")

    on_cpu = self.device.type == "cpu"
    if not on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    last_hook = None
    for component in (self.text_encoder, self.unet, self.vae):
        _, last_hook = cpu_offload_with_hook(component, target_device, prev_module_hook=last_hook)

    if self.safety_checker is not None:
        _, last_hook = cpu_offload_with_hook(self.safety_checker, target_device, prev_module_hook=last_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = last_hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Device on which the pipeline's models are executed.

    If the unet carries no `_hf_hook`, this is simply `self.device`. Otherwise the unet's modules are scanned
    and the first hook exposing a non-`None` `execution_device` decides; `self.device` is the fallback.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            execution_device = getattr(hook, "execution_device", None)
            if execution_device is not None:
                return torch.device(execution_device)
    return self.device
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
         prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        `torch.FloatTensor`: the prompt embeddings repeated `num_images_per_prompt` times per prompt. With
        classifier free guidance the negative embeddings are prepended along the batch dimension.

    Raises:
        TypeError: if `negative_prompt` and `prompt` are of different types.
        ValueError: if a list `negative_prompt` does not match the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # tokenize a second time without truncation, only to detect and report what was cut off
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        # the first element of the text encoder output is used as the embedding
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # pad the negative prompt to the same sequence length as the positive embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """Screen `image` with the pipeline's safety checker, when one is configured.

    Args:
        image: Generated images, either a tensor or an array accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted pixel values are cast to before checking.

    Returns:
        A `(image, has_nsfw_concept)` tuple. When `self.safety_checker` is `None`
        the image is returned untouched and `has_nsfw_concept` is `None`.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert from whichever representation was passed in.
    if torch.is_tensor(image):
        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
    else:
        feature_extractor_input = self.image_processor.numpy_to_pil(image)

    safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments accepted by `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so only the arguments the
    current scheduler declares are forwarded.

    Args:
        generator: Random generator(s) forwarded when `step` has a `generator` parameter.
        eta: DDIM eta (η, see https://arxiv.org/abs/2010.02502, expected in [0, 1]);
            forwarded only when `step` has an `eta` parameter.

    Returns:
        A dict holding zero, one or both of the keys `"eta"` and `"generator"`.
    """
    # Inspect the signature once; both membership tests read the same parameter mapping.
    step_parameters = inspect.signature(self.scheduler.step).parameters

    extra_step_kwargs = {}
    if "eta" in step_parameters:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_parameters:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
def check_inputs(
    self,
    prompt,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the prompt-related arguments shared by the pipeline entry points.

    Args:
        prompt: Prompt as `str` or `list`; mutually exclusive with `prompt_embeds`.
        strength: Number that must lie in `[0, 1]`.
        callback_steps: Positive `int`.
        negative_prompt: Negative prompt; mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: Pre-computed prompt embeddings.
        negative_prompt_embeds: Pre-computed negative embeddings; when given together with
            `prompt_embeds` both must expose the same `.shape`.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if strength is None or strength < 0 or strength > 1:
        raise ValueError(
            f"The value of `strength` should be in [0.0, 1.0] but is {strength} of type {type(strength)}."
        )

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )


def check_source_inputs(
    self,
    source_prompt=None,
    source_negative_prompt=None,
    source_prompt_embeds=None,
    source_negative_prompt_embeds=None,
):
    """Validate the source-prompt arguments of `generate_mask`.

    Exactly one of `source_prompt` / `source_prompt_embeds` must be given, at most one of
    `source_negative_prompt` / `source_negative_prompt_embeds`, and when both embedding
    arguments are given their `.shape` must match.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if source_prompt is not None and source_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}."
            " Please make sure to only forward one of the two."
        )
    elif source_prompt is None and source_prompt_embeds is None:
        # The message used to name `source_image`, which is not an argument of this method.
        raise ValueError(
            "Provide either `source_prompt` or `source_prompt_embeds`. Cannot leave both `source_prompt` and"
            " `source_prompt_embeds` undefined."
        )
    elif source_prompt is not None and not isinstance(source_prompt, (str, list)):
        raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}")

    if source_negative_prompt is not None and source_negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:"
            f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if source_prompt_embeds is not None and source_negative_prompt_embeds is not None:
        if source_prompt_embeds.shape != source_negative_prompt_embeds.shape:
            raise ValueError(
                "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed"
                f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !="
                f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
            )


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """Select the tail of `self.scheduler.timesteps` that corresponds to `strength`.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to run.
        device: Unused; kept for signature compatibility.

    Returns:
        `(timesteps, steps)` where `timesteps` is the scheduler's timestep sequence with the
        first `t_start * self.scheduler.order` entries dropped and `steps` is
        `num_inference_steps - t_start`.
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)

    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
    return timesteps, num_inference_steps - t_start


def get_inverse_timesteps(self, num_inference_steps, strength, device):
    """Select the head of `self.inverse_scheduler.timesteps` that corresponds to `strength`.

    Args:
        num_inference_steps: Number of steps the inverse scheduler was configured with.
        strength: Fraction of the schedule to run.
        device: Unused; kept for signature compatibility.

    Returns:
        `(timesteps, steps)` where `timesteps` is the inverse scheduler's sequence with its
        last `t_start` entries dropped and `steps` is `num_inference_steps - t_start`.
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)

    # `[:-0]` would yield an empty slice, so the full schedule is returned explicitly.
    if t_start == 0:
        return self.inverse_scheduler.timesteps, num_inference_steps

    timesteps = self.inverse_scheduler.timesteps[:-t_start]
    return timesteps, num_inference_steps - t_start
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents
def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
    """Turn `image` into scaled VAE latents and broadcast them to `batch_size`.

    Args:
        image: Image batch. An input whose dimension 1 has size 4 is treated as latents
            already and is not encoded.
        batch_size: Effective batch size the returned latents must have.
        dtype: Dtype `image` is cast to.
        device: Device `image` is moved to.
        generator: Optional generator, or list of `batch_size` generators, used when
            sampling from the VAE posterior.

    Returns:
        Latents whose first dimension equals `batch_size`.

    Raises:
        ValueError: If `image` has an unsupported type, a generator list has the wrong
            length, or the latent batch cannot be evenly duplicated to `batch_size`.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    # NOTE(review): the type check above admits PIL images and lists, but `.to` below exists only on
    # tensors; the callers visible in this file run `preprocess(image)` first.
    image = image.to(device=device, dtype=dtype)

    if image.shape[1] == 4:
        latents = image
    else:
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            # One generator per sample: encode and sample each image separately.
            latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            latents = torch.cat(latents, dim=0)
        else:
            latents = self.vae.encode(image).latent_dist.sample(generator)

        latents = self.vae.config.scaling_factor * latents

    if batch_size != latents.shape[0]:
        if batch_size % latents.shape[0] == 0:
            # expand image_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_latents_per_image = batch_size // latents.shape[0]
            latents = torch.cat([latents] * additional_latents_per_image, dim=0)
        else:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
            )
    else:
        latents = torch.cat([latents], dim=0)

    return latents


def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
    """Convert a model output into an epsilon (noise) prediction.

    The conversion depends on `self.inverse_scheduler.config.prediction_type` and uses
    `alpha = self.inverse_scheduler.alphas_cumprod[timestep]`, `beta = 1 - alpha`:

    - `"epsilon"`: `model_output` is returned unchanged.
    - `"sample"`: `(sample - sqrt(alpha) * model_output) / sqrt(beta)`.
    - `"v_prediction"`: `sqrt(alpha) * model_output + sqrt(beta) * sample`.

    Raises:
        ValueError: For any other prediction type.
    """
    pred_type = self.inverse_scheduler.config.prediction_type
    alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t

    if pred_type == "epsilon":
        return model_output
    if pred_type == "sample":
        return (sample - alpha_prod_t**0.5 * model_output) / beta_prod_t**0.5
    if pred_type == "v_prediction":
        return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    raise ValueError(
        f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
    )


@torch.no_grad()
def generate_mask(
    self,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    target_prompt: Optional[Union[str, List[str]]] = None,
    target_negative_prompt: Optional[Union[str, List[str]]] = None,
    target_prompt_embeds: Optional[torch.FloatTensor] = None,
    target_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    source_prompt: Optional[Union[str, List[str]]] = None,
    source_negative_prompt: Optional[Union[str, List[str]]] = None,
    source_prompt_embeds: Optional[torch.FloatTensor] = None,
    source_negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    num_maps_per_mask: Optional[int] = 10,
    mask_encode_strength: Optional[float] = 0.5,
    mask_thresholding_ratio: Optional[float] = 3.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    output_type: Optional[str] = "np",
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function used to generate a latent mask given a mask prompt, a target prompt, and an image.

    Args:
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be used for computing the mask.
        target_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass
            `target_prompt_embeds` instead.
        target_negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `target_negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is not greater than `1`).
        target_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `target_prompt` input argument.
        target_negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from the
            `target_negative_prompt` input argument.
        source_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit:
            Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If
            not defined, one has to pass `source_prompt_embeds` instead.
        source_negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the semantic mask generation away from. If not defined, one may pass
            `source_negative_prompt_embeds` instead.
        source_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings to guide the semantic mask generation. If not provided, text
            embeddings will be generated from the `source_prompt` input argument.
        source_negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings to negatively guide the semantic mask generation. If not provided,
            they will be generated from the `source_negative_prompt` input argument.
        num_maps_per_mask (`int`, *optional*, defaults to 10):
            The number of noise maps sampled to generate the semantic mask.
        mask_encode_strength (`float`, *optional*, defaults to 0.5):
            Conceptually, the strength of the noise maps sampled to generate the semantic mask. Must be between
            0 and 1.
        mask_thresholding_ratio (`float`, *optional*, defaults to 3.0):
            The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before
            mask binarization.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps the scheduler is configured with.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). Guidance is enabled by setting `guidance_scale > 1`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated mask. Choose between `"pil"` (`PIL.Image.Image`) or `np.array`.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the UNet. Its `"scale"` entry, when
            present, is also forwarded to the prompt encoder as the LoRA scale.

    Examples:

    Returns:
        `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a
        `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of
        single-channel binary images with dimensions `(height // self.vae_scale_factor, width //
        self.vae_scale_factor)`, otherwise the `np.array` will have shape `(batch_size, height //
        self.vae_scale_factor, width // self.vae_scale_factor)`.
    """
    # 1. Check inputs (Provide dummy argument for callback_steps)
    self.check_inputs(
        target_prompt,
        mask_encode_strength,
        1,
        target_negative_prompt,
        target_prompt_embeds,
        target_negative_prompt_embeds,
    )

    self.check_source_inputs(
        source_prompt,
        source_negative_prompt,
        source_prompt_embeds,
        source_negative_prompt_embeds,
    )

    # Same guard as `invert`: fail with a clear message instead of deep inside `preprocess`.
    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if num_maps_per_mask is None or not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0:
        raise ValueError(
            f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type"
            f" {type(num_maps_per_mask)}."
        )

    if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0:
        raise ValueError(
            f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type"
            f" {type(mask_thresholding_ratio)}."
        )

    # 2. Define call parameters
    if target_prompt is not None and isinstance(target_prompt, str):
        batch_size = 1
    elif target_prompt is not None and isinstance(target_prompt, list):
        batch_size = len(target_prompt)
    else:
        batch_size = target_prompt_embeds.shape[0]
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompts
    # The scale used to be looked up and then discarded; forward it so a LoRA-patched text encoder sees it.
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None)
    target_prompt_embeds = self._encode_prompt(
        target_prompt,
        device,
        num_maps_per_mask,
        do_classifier_free_guidance,
        target_negative_prompt,
        prompt_embeds=target_prompt_embeds,
        negative_prompt_embeds=target_negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    source_prompt_embeds = self._encode_prompt(
        source_prompt,
        device,
        num_maps_per_mask,
        do_classifier_free_guidance,
        source_negative_prompt,
        prompt_embeds=source_prompt_embeds,
        negative_prompt_embeds=source_negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess image
    image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)

    # 5. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device)
    encode_timestep = timesteps[0]

    # 6. Prepare image latents and add noise with specified strength
    image_latents = self.prepare_image_latents(
        image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator
    )
    noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype)
    image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep)

    # One copy of the latents per embedding chunk: (neg-source, source, neg-target, target) with
    # guidance, (source, target) without.
    latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2))
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep)

    # 7. Predict the noise residual
    prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds])
    noise_pred = self.unet(
        latent_model_input,
        encode_timestep,
        encoder_hidden_states=prompt_embeds,
        cross_attention_kwargs=cross_attention_kwargs,
    ).sample

    if do_classifier_free_guidance:
        noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4)
        noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src)
        noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond)
    else:
        noise_pred_source, noise_pred_target = noise_pred.chunk(2)

    # 8. Compute the mask from the absolute difference of predicted noise residuals
    # TODO: Consider smoothing mask guidance map
    mask_guidance_map = (
        torch.abs(noise_pred_target - noise_pred_source)
        .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:])
        .mean([1, 2])
    )
    clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio
    semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude
    semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1)
    mask_image = semantic_mask_image.cpu().numpy()

    # 9. Convert to Numpy array or PIL.
    if output_type == "pil":
        mask_image = self.image_processor.numpy_to_pil(mask_image)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return mask_image
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. decode_latents (`bool`, *optional*, defaults to `False`): Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` will decode all inverted latents for each timestep into a list of generated images. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). lambda_auto_corr (`float`, *optional*, defaults to 20.0): Lambda parameter to control auto correction lambda_kl (`float`, *optional*, defaults to 20.0): Lambda parameter to control Kullback–Leibler divergence output num_reg_steps (`int`, *optional*, defaults to 0): Number of regularization loss steps num_auto_corr_rolls (`int`, *optional*, defaults to 5): Number of auto correction roll steps Examples: Returns: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted latents tensors ordered by increasing noise, and then second is the corresponding decoded images if `decode_latents` is `True`, otherwise `None`. """ # 1. 
Check inputs self.check_inputs( prompt, inpaint_strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if cross_attention_kwargs is None: cross_attention_kwargs = {} device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image image = preprocess(image) # 4. Prepare latent variables num_images_per_prompt = 1 latents = self.prepare_image_latents( image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator ) # 5. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 6. Prepare timesteps self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. 
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order inverted_latents = [latents.detach().clone()] with self.progress_bar(total=num_inference_steps - 1) as progress_bar: for i, t in enumerate(timesteps[:-1]): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) if num_reg_steps > 0: with torch.enable_grad(): for _ in range(num_reg_steps): if lambda_auto_corr > 0: for _ in range(num_auto_corr_rolls): var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) l_ac = auto_corr_loss(var_epsilon, generator=generator) l_ac.backward() grad = var.grad.detach() / num_auto_corr_rolls noise_pred = noise_pred - lambda_auto_corr * grad if lambda_kl > 0: var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) # Derive epsilon from model output before regularizing to IID standard normal var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) l_kld = kl_divergence(var_epsilon) l_kld.backward() grad = var.grad.detach() noise_pred = noise_pred - lambda_kl * grad noise_pred = noise_pred.detach() # compute the previous noisy sample x_t -> x_t-1 latents = self.inverse_scheduler.step(noise_pred, t, 
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    image_latents: torch.FloatTensor = None,
    inpaint_strength: Optional[float] = 0.8,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask
            will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be
            converted to a single channel (luminance) before use. If it's a tensor, it should contain one color
            channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`.
        image_latents (`PIL.Image.Image` or `torch.FloatTensor`):
            Partially noised image latents from the inversion process to be used as inputs for image generation.
        inpaint_strength (`float`, *optional*, defaults to 0.8):
            Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When
            `inpaint_strength` is 1, the denoising process will be run on the masked area for the full number of
            iterations specified in `num_inference_steps`. `image_latents` will be used as a reference for the
            masked area, adding more noise to that region the larger the `inpaint_strength`. If
            `inpaint_strength` is 0, no inpainting will occur.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Accepted for signature compatibility with the other pipelines. The body below never reads this
            argument: the starting latents are always taken from the first entry of `image_latents`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Passing `"latent"`
            skips VAE decoding and the safety checker and returns the final latents.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If `mask_image` or `image_latents` is missing, or if `image_latents` does not have the latent
            shape / batch size / number of timesteps implied by the mask and the scheduler.
    """

    # 1. Check inputs
    self.check_inputs(
        prompt,
        inpaint_strength,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    if mask_image is None:
        raise ValueError(
            "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts."
        )
    if image_latents is None:
        raise ValueError(
            "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images."
        )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Preprocess mask
    mask_image = preprocess_mask(mask_image, batch_size)
    latent_height, latent_width = mask_image.shape[-2:]
    # Repeat each mask `num_images_per_prompt` times *in place* (m1, m1, m2, m2, ...) so that the mask batch
    # lines up with `image_latents` below, which is expanded with `repeat_interleave` as well. Tiling the
    # whole batch (m1, m2, m1, m2, ...) would pair samples with the wrong mask whenever both the batch size
    # and `num_images_per_prompt` are greater than 1.
    mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
    mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype)

    # 5. Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device)

    # 6. Preprocess image latents
    image_latents = preprocess(image_latents)
    latent_shape = (self.vae.config.latent_channels, latent_height, latent_width)
    if image_latents.shape[-3:] != latent_shape:
        raise ValueError(
            f"Each latent image in `image_latents` must have shape {latent_shape}, "
            f"but has shape {image_latents.shape[-3:]}"
        )
    if image_latents.ndim == 4:
        # a flattened (batch * timesteps, C, H, W) stack is unfolded back to (batch, timesteps, C, H, W)
        image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape)
    if image_latents.shape[:2] != (batch_size, len(timesteps)):
        raise ValueError(
            f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, "
            f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps."
        )
    # timestep-major layout so that `image_latents[i]` is the whole batch for denoising step `i`
    image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1)
    image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype)

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    latents = image_latents[0].detach().clone()
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # mask with inverted latents from appropriate timestep - use original image latent for last step
            latents = latents * mask_image + image_latents[i] * (1 - mask_image)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are passed through without denormalization
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
class StableDiffusionImageVariationPipeline(DiffusionPipeline):
    r"""
    Pipeline to generate variations from an input image using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    # TODO: feature_extractor is required to encode images (if they are in PIL format),
    # we should give a descriptive message if the pipeline doesn't have one.
    _optional_components = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        r"""
        Validate the components, patch legacy unet configs and register everything on the pipeline.

        A warning is logged when `safety_checker` is `None` while `requires_safety_checker` is `True`. Unet configs
        saved by a diffusers version older than 0.9.0 with `sample_size < 64` trigger a deprecation notice and have
        `sample_size` overridden to 64.

        Raises:
            ValueError: If a `safety_checker` is given without a `feature_extractor`.
        """
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            # `Logger.warn` is a deprecated alias of `Logger.warning`
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # the first fragment must be an f-string, otherwise `{self.__class__}` is printed literally
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            # the unet config is frozen, so a patched copy replaces the internal dict
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # one factor of two per VAE block transition
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        image_encoder, vae and safety checker (when present) are handed to accelerate's `cpu_offload` with
        `cuda:{gpu_id}` as the execution device.

        Raises:
            ImportError: If `accelerate` is not available.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
            # the safety checker is optional and may be None
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
        r"""
        Turn the conditioning image(s) into the embeddings passed to the unet as `encoder_hidden_states`.

        Non-tensor inputs go through `self.feature_extractor` first. The projected CLIP embedding gets a
        singleton sequence axis and is repeated `num_images_per_prompt` times. With classifier free guidance
        an all-zero embedding of the same shape is concatenated in front of the conditional one.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # the unconditional branch is conditioned on an all-zero embedding
            negative_prompt_embeds = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and image embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    # Runs the optional safety checker; returns `(image, None)` when no checker is configured.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Deprecated helper: decodes latents with the VAE into a float32 NHWC numpy array in [0, 1].
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Builds the kwargs for `scheduler.step`, forwarding `eta` / `generator` only if its signature accepts them.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, image, height, width, callback_steps):
        r"""
        Validate the arguments of `__call__`.

        Raises:
            ValueError: If `image` is neither a tensor, a PIL image nor a list, if `height` or `width` is not
                divisible by 8, or if `callback_steps` is not a positive integer.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # `None` is not an `int`, so the isinstance test also rejects a missing value
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Creates (or moves) the initial latents and scales them by the scheduler's `init_noise_sigma`.
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                configuration of
                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the conditioning
                `image`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per input image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different inputs. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. Passing `"latent"`
                skips VAE decoding and the safety checker and returns the final latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        # images flagged by the safety checker are passed through without denormalization
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from io import BytesIO

        >>> from diffusers import StableDiffusionImg2ImgPipeline

        >>> device = "cuda"
        >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
        >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
        >>> pipe = pipe.to(device)

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        >>> response = requests.get(url)
        >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> init_image = init_image.resize((768, 512))

        >>> prompt = "A fantasy landscape, trending on artstation"

        >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
        >>> images[0].save("fantasy_landscape.png")
        ```
"""


def preprocess(image):
    """
    Deprecated helper that converts PIL image(s) or a list of tensors into a single batched tensor.

    A `torch.Tensor` input is returned untouched. PIL images are resized down to a multiple of 8 in each dimension
    (all images use the size derived from the first one), scaled to `[-1, 1]` and laid out channels-first. A list
    of tensors is concatenated along dim 0. Always emits a `FutureWarning`.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


class StableDiffusionImg2ImgPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    In addition the pipeline inherits the following loading methods:
        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

    as well as the following saving methods:
        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
             prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
    ):
        """
        Validate the arguments of `__call__`, raising `ValueError` on the first violation found.

        Checks that `strength` lies in `[0, 1]`, that `callback_steps` is a positive `int`, that exactly one of
        `prompt` / `prompt_embeds` is given (and `prompt` is a `str` or `list`), that `negative_prompt` and
        `negative_prompt_embeds` are not both given, and that `prompt_embeds` and `negative_prompt_embeds` have the
        same shape when both are passed.
        """
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # `None` is not an `int`, so this single test also rejects a missing value.
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def get_timesteps(self, num_inference_steps, strength, device):
        """
        Select the part of the scheduler's timestep schedule that will actually be run for a given `strength`.

        Returns a tuple `(timesteps, steps)` where `timesteps` is the tail of `self.scheduler.timesteps` starting at
        `t_start * self.scheduler.order` and `steps` is `num_inference_steps - t_start`. `device` is accepted but not
        used by this method.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """
        Build the noised starting latents for the denoising loop from `image`.

        `image` is moved to `device`/`dtype`. If its second dimension has size 4 it is used directly as latents;
        otherwise it is encoded with the VAE (one generator per sample when `generator` is a list) and multiplied by
        `self.vae.config.scaling_factor`. The latents are repeated when the effective batch size
        (`batch_size * num_images_per_prompt`) is an exact multiple of the number of images, which emits a
        deprecation notice. Noise is then added for `timestep` through `self.scheduler.add_noise`.

        Raises:
            ValueError: if `image` has an unsupported type, if the generator list length differs from the effective
                batch size, or if the images cannot be evenly duplicated to the effective batch size.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # NOTE(review): only tensors support `.to`; this relies on the caller having preprocessed `image` already.
        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                init_latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = self.vae.encode(image).latent_dist.sample(generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded
                again.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter will be modulated by `strength`.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
    """
    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will
    be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
    the ``image`` and ``1`` for the ``mask``.

    The ``image`` will be converted to ``torch.float32``; PIL / numpy images are additionally rescaled from
    ``[0, 255]`` to ``[-1, 1]``. The ``mask`` will be binarized (values ``>= 0.5`` become ``1``, everything else
    ``0``). Tensor and numpy masks keep their dtype; PIL masks become ``float32``.

    The caller's ``image`` and ``mask`` objects are never modified: binarization is done on a copy.

    Args:
        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
        height (int): Target height. Only used to resize PIL inputs; tensors and arrays are not resized.
        width (int): Target width. Only used to resize PIL inputs; tensors and arrays are not resized.
        return_image (bool): When ``True`` the preprocessed image is returned as a third tuple element.

    Raises:
        ValueError: ``image`` or ``mask`` is ``None``.
        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
        AssertionError: tensor ``mask`` and ``image`` do not have 4 dimensions after batching, or differ in spatial
            dimensions or batch size.
        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).

    Returns:
        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels
        x height x width``. When ``return_image`` is ``True`` the triple (mask, masked_image, image) is returned.
    """

    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if mask is None:
        raise ValueError("`mask_image` input cannot be undefined.")

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")

        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        # NOTE: kept as assertions (not ValueError) so callers that expect AssertionError keep working.
        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError("Mask should be in [0, 1] range")

        # Binarize mask. `unsqueeze` returns views that share storage with the caller's tensor, so work on a
        # copy; otherwise the in-place writes below would silently overwrite the caller's mask.
        mask = mask.clone()
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t passed height and width
            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
            image = [np.array(i.convert("RGB"))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # NHWC -> NCHW, then rescale [0, 255] -> [-1, 1]
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `np.concatenate` always allocates a new array, so these in-place writes do not touch the caller's data.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    # Zero out the pixels that are going to be repainted (mask == 1).
    masked_image = image * (mask < 0.5)

    # n.b. ensure backwards compatibility as old function does not return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
Default text-to-image stable diffusion checkpoints, such as [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with this pipeline, but might be less performant. Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration" " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" " incorrect results in future versions. 
If you have downloaded this checkpoint from the Hugging Face" " Hub, it would be very nice if you could open a Pull request for the" " `scheduler/scheduler_config.json` file" ) deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from 
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload every model of the pipeline to CPU through accelerate's `cpu_offload`.

    The pipeline is first moved to CPU when it is not already there. `unet`, `text_encoder` and `vae` are then
    registered for offloading with `cuda:{gpu_id}` as execution device, followed by the safety checker when one is
    present. Offloading works on a submodule basis, so memory savings are higher than with
    `enable_model_cpu_offload`, but performance is lower.

    Raises:
        ImportError: accelerate is not available or is older than 0.14.0.
    """
    if not is_accelerate_available() or not is_accelerate_version(">=", "0.14.0"):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

    from accelerate import cpu_offload

    target_device = torch.device(f"cuda:{gpu_id}")

    already_on_cpu = self.device.type == "cpu"
    if not already_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        # otherwise we don't see the memory savings (but they probably exist)
        torch.cuda.empty_cache()

    core_models = (self.unet, self.text_encoder, self.vae)
    for model in core_models:
        cpu_offload(model, target_device)

    checker = self.safety_checker
    if checker is not None:
        cpu_offload(checker, execution_device=target_device, offload_buffers=True)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Schedulers do not share one signature, so `eta` and `generator` are only forwarded when the scheduler's `step`
    declares a parameter of that name. eta (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        dict: contains "eta" and/or "generator" (in that order) for the parameters `step` accepts.
    """
    accepted = inspect.signature(self.scheduler.step).parameters
    candidates = (("eta", eta), ("generator", generator))
    return {name: value for name, value in candidates if name in accepted}
def check_inputs(
    self,
    prompt,
    height,
    width,
    strength,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Raises:
        ValueError: `strength` is below 0 or above 1; `height` or `width` is not a multiple of 8;
            `callback_steps` is not a positive `int`; `prompt` and `prompt_embeds` are both given or both
            missing; `prompt` is neither `str` nor `list`; `negative_prompt` and `negative_prompt_embeds` are
            both given; or `prompt_embeds` and `negative_prompt_embeds` have different shapes.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an int, so it is rejected by the same test as every other non-integer.
    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Pre-computed embeddings are combined batch-wise later on, so their shapes have to agree.
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """
    Encode `image` with the VAE and return latents multiplied by `vae.config.scaling_factor`.

    When `generator` is a list, the batch is encoded one sample at a time and sample `i` is drawn with
    `generator[i]`; the per-sample latents are concatenated along dim 0. Otherwise the whole batch is encoded and
    sampled in one call with the given generator.
    """
    if not isinstance(generator, list):
        sampled = self.vae.encode(image).latent_dist.sample(generator=generator)
    else:
        per_sample = []
        for idx in range(image.shape[0]):
            # Slice (rather than index) so each sample keeps its batch dimension.
            posterior = self.vae.encode(image[idx : idx + 1]).latent_dist
            per_sample.append(posterior.sample(generator=generator[idx]))
        sampled = torch.cat(per_sample, dim=0)

    return self.vae.config.scaling_factor * sampled
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Select the tail of the scheduler's timesteps that corresponds to `strength`.

    `int(num_inference_steps * strength)` steps (capped at `num_inference_steps`) are kept; the skipped leading
    steps are multiplied by `self.scheduler.order` to obtain the slice offset into `self.scheduler.timesteps`.
    `device` is accepted but not used here.

    Returns:
        tuple: (remaining timesteps, number of inference steps that will actually run).
    """
    steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
    first_step = max(num_inference_steps - steps_to_run, 0)

    offset = first_step * self.scheduler.order
    remaining = self.scheduler.timesteps[offset:]

    return remaining, num_inference_steps - first_step
@torch.no_grad()
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    strength: float = 1.0,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image
            will be masked out with `mask_image` and repainted according to `prompt`.
        mask_image (`PIL.Image.Image`):
            `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
            repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be
            converted to a single channel (luminance) before use. If it's a tensor, it should contain one
            color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        strength (`float`, *optional*, defaults to 1.):
            Conceptually, indicates how much to transform the masked portion of the reference `image`. Must
            be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger
            the `strength`. The number of denoising steps depends on the amount of noise initially added.
            When `strength` is 1, added noise will be maximum and the denoising process will run for the full
            number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially
            ignores the masked portion of the reference `image`.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
            to the text `prompt`, usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
            `guidance_scale` is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only
            applies to [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead
            of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will
            be called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback
            will be called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined
            under `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    ```py
    >>> import PIL
    >>> import requests
    >>> import torch
    >>> from io import BytesIO

    >>> from diffusers import StableDiffusionInpaintPipeline


    >>> def download_image(url):
    ...     response = requests.get(url)
    ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


    >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

    >>> init_image = download_image(img_url).resize((512, 512))
    >>> mask_image = download_image(mask_url).resize((512, 512))

    >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
    ...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ... )
    >>> pipe = pipe.to("cuda")

    >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
    >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
    ```

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images, and the
        second element is a list of `bool`s denoting whether the corresponding generated image likely
        represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.

    Raises:
        ValueError: If `strength` reduces the number of denoising steps below 1, or if the UNet's input
            channel count is neither 4 nor 9, or does not match latents + mask + masked-image channels.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        height,
        width,
        strength,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(
        num_inference_steps=num_inference_steps, strength=strength, device=device
    )
    # check that number of inference steps is not < 1 - as this doesn't make sense
    if num_inference_steps < 1:
        raise ValueError(
            f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
            f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
        )
    # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
    # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
    is_strength_max = strength == 1.0

    # 5. Preprocess mask and image
    mask, masked_image, init_image = prepare_mask_and_masked_image(
        image, mask_image, height, width, return_image=True
    )

    # 6. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    num_channels_unet = self.unet.config.in_channels
    # A 4-channel UNet has no mask inputs, so the image latents are needed
    # later to re-impose the unmasked region after every scheduler step.
    return_image_latents = num_channels_unet == 4

    latents_outputs = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
        image=init_image,
        timestep=latent_timestep,
        is_strength_max=is_strength_max,
        return_noise=True,
        return_image_latents=return_image_latents,
    )

    if return_image_latents:
        latents, noise, image_latents = latents_outputs
    else:
        latents, noise = latents_outputs

    # 7. Prepare mask latent variables
    mask, masked_image_latents = self.prepare_mask_latents(
        mask,
        masked_image,
        batch_size * num_images_per_prompt,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        do_classifier_free_guidance,
    )
    # NOTE(review): the encoded `init_image` is not read again in this method. It is
    # kept because the VAE encode may consume `generator` state — confirm before removing.
    init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
    init_image = self._encode_vae_image(init_image, generator=generator)

    # 8. Check that sizes of mask, masked image and latents match
    if num_channels_unet == 9:
        # default case for runwayml/stable-diffusion-inpainting
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                " `pipeline.unet` or your `mask_image` or `image` input."
            )
    elif num_channels_unet != 4:
        raise ValueError(
            f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
        )

    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 10. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            # concat latents, mask, masked_image_latents in the channel dimension
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            if num_channels_unet == 9:
                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if num_channels_unet == 4:
                # NOTE(review): only the first batch entry of the image latents and mask is
                # used and broadcast — confirm this is intended for batch sizes > 1.
                init_latents_proper = image_latents[:1]
                init_mask = mask[:1]

                if i < len(timesteps) - 1:
                    # Re-noise the original latents to the level of the next timestep.
                    noise_timestep = timesteps[i + 1]
                    init_latents_proper = self.scheduler.add_noise(
                        init_latents_proper, noise, torch.tensor([noise_timestep])
                    )

                # Keep the original content where the mask is 0, the denoised result where it is 1.
                latents = (1 - init_mask) * init_latents_proper + init_mask * latents

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type != "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def preprocess_image(image, batch_size):
    """Turn an image into a batched float tensor whose sides are multiples of 8.

    Args:
        image: Object exposing ``.size`` as ``(width, height)`` and a PIL-style
            ``.resize``. After conversion with ``np.array`` it must be
            three-dimensional, because it is transposed with axes
            ``(0, 3, 1, 2)`` once a leading axis is added.
        batch_size: Number of copies stacked along the first axis.

    Returns:
        A tensor built from ``batch_size`` stacked copies, with values mapped
        by ``2 * (x / 255) - 1``.
    """
    # Round both sides down to an integer multiple of 8 before resizing.
    width, height = (side - side % 8 for side in image.size)
    resized = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])

    pixels = np.array(resized).astype(np.float32) / 255.0
    # Add a batch axis and move the last axis to position 1.
    channels_first = pixels[None].transpose(0, 3, 1, 2)
    batch = np.vstack([channels_first] * batch_size)

    return 2.0 * torch.from_numpy(batch) - 1.0
def preprocess_mask(mask, batch_size, scale_factor=8):
    """Convert an inpainting mask to a float tensor at latent resolution.

    Two input kinds are handled:

    * Non-tensor input is treated as a PIL-style image: it is converted to
      mode ``"L"``, resized to ``(w // scale_factor, h // scale_factor)``
      after rounding both sides down to a multiple of 8, scaled by ``1/255``,
      tiled to 4 channels, repeated ``batch_size`` times and inverted
      (``1 - mask``).
    * Tensor input of any dtype or device is accepted; previously only CPU
      float32 tensors passed the type check and every other tensor failed on
      ``.convert``. The channel axis, of size 1 or 3, may be second or fourth.
      Channels are averaged to one and the result is interpolated to latent
      resolution. Unlike the image branch, tensor masks are neither inverted,
      tiled to 4 channels nor repeated ``batch_size`` times.

    Args:
        mask: Mask image or 4-D mask tensor.
        batch_size: Number of copies to stack (image branch only).
        scale_factor: Divisor applied to the spatial size.

    Returns:
        The processed mask as a ``torch.Tensor``.

    Raises:
        ValueError: If a tensor mask has no channel axis of size 1 or 3 in the
            second or fourth position.
    """
    if not isinstance(mask, torch.Tensor):
        mask = mask.convert("L")
        w, h = mask.size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
        mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
        mask = np.array(mask).astype(np.float32) / 255.0
        mask = np.tile(mask, (4, 1, 1))
        mask = np.vstack([mask[None]] * batch_size)
        mask = 1 - mask  # repaint white, keep black
        return torch.from_numpy(mask)

    valid_mask_channel_sizes = [1, 3]
    # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
    if mask.shape[3] in valid_mask_channel_sizes:
        mask = mask.permute(0, 3, 1, 2)
    elif mask.shape[1] not in valid_mask_channel_sizes:
        raise ValueError(
            f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension,"
            f" but received mask of shape {tuple(mask.shape)}"
        )

    # `mean` is undefined for integer/bool tensors, so promote those first.
    if not mask.is_floating_point():
        mask = mask.to(torch.float32)

    # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
    mask = mask.mean(dim=1, keepdim=True)
    h, w = mask.shape[-2:]
    h, w = (x - x % 8 for x in (h, w))  # resize to integer multiple of 8
    return torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Register the pipeline components and patch outdated configurations.

    A deprecation notice for the class itself is always emitted. Three known
    configuration problems are then reported through `deprecate` and patched
    in place by replacing the component's `_internal_dict`:

    * scheduler `steps_offset` other than 1 is set to 1;
    * scheduler `clip_sample` equal to `True` is set to `False`;
    * UNet `sample_size` below 64, on configs older than 0.9.0, is set to 64.

    Args:
        vae: Autoencoder; its `block_out_channels` config determines
            `self.vae_scale_factor`.
        text_encoder: Text encoder registered on the pipeline.
        tokenizer: Tokenizer registered on the pipeline.
        unet: Denoising UNet.
        scheduler: Scheduler used together with `unet`.
        safety_checker: Optional safety checker; `None` disables it.
        feature_extractor: Required whenever `safety_checker` is given.
        requires_safety_checker: When `True`, a warning is logged if
            `safety_checker` is `None`.

    Raises:
        ValueError: If a safety checker is given without a feature extractor.
    """
    super().__init__()

    # Adjacent literals are concatenated verbatim, so each continuation needs
    # its own leading space.
    deprecation_message = (
        f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality"
        " by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533"
        " for more information."
    )
    deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The f-prefix is required for `{self.__class__}` to be interpolated.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Register accelerate's `cpu_offload` on every model of the pipeline.

    The pipeline is first moved to the CPU (and the CUDA cache emptied) if it is not already there. `unet`,
    `text_encoder` and `vae` are then passed to `cpu_offload` with `cuda:{gpu_id}` as the device; the safety
    checker, when present, is offloaded with `offload_buffers=True`.

    Args:
        gpu_id (`int`, defaults to 0): Index of the CUDA device used for execution.

    Raises:
        ImportError: If accelerate is unavailable or older than 0.14.0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    target = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    for model in (self.unet, self.text_encoder, self.vae):
        cpu_offload(model, target)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=target, offload_buffers=True)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Chain accelerate's `cpu_offload_with_hook` over the models of the pipeline.

    The pipeline is first moved to the CPU (and the CUDA cache emptied) if it is not already there. The hook
    returned for each model is passed as `prev_module_hook` to the next one, in the order `text_encoder`,
    `unet`, `vae` and then the safety checker when present. The last hook is stored on
    `self.final_offload_hook` so the caller can offload the final model manually.

    Args:
        gpu_id (`int`, defaults to 0): Index of the CUDA device used for execution.

    Raises:
        ImportError: If accelerate is unavailable or older than 0.17.0.dev0.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")):
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    last_hook = None
    for model in (self.text_encoder, self.unet, self.vae):
        _, last_hook = cpu_offload_with_hook(model, target, prev_module_hook=last_hook)

    if self.safety_checker is not None:
        _, last_hook = cpu_offload_with_hook(self.safety_checker, target, prev_module_hook=last_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = last_hook


@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Return the device the pipeline's models execute on.

    When the UNet carries no `_hf_hook` attribute this is `self.device`. Otherwise the UNet's sub-modules are
    scanned and the first hook exposing a non-`None` `execution_device` decides the result; `self.device` is
    the fallback when no such hook is found.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            execution_device = getattr(hook, "execution_device", None)
            if execution_device is not None:
                return torch.device(execution_device)
    return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only read when guidance is enabled and no
            `negative_prompt_embeds` are given.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from
            `negative_prompt` (or from empty strings when that is `None`).
        lora_scale (`float`, *optional*):
            Stored on `self._lora_scale` when the pipeline is a `LoraLoaderMixin`.

    Returns:
        The prompt embeddings repeated `num_images_per_prompt` times; with classifier free guidance the
        negative embeddings are concatenated in front of them along the batch axis.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    def _attention_mask_for(tokenized):
        # Only forward an attention mask when the text encoder config asks for one.
        encoder_config = self.text_encoder.config
        if hasattr(encoder_config, "use_attention_mask") and encoder_config.use_attention_mask:
            return tokenized.attention_mask.to(device)
        return None

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        encoder_output = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=_attention_mask_for(text_inputs),
        )
        prompt_embeds = encoder_output[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if not do_classifier_free_guidance:
        return prompt_embeds

    # get unconditional embeddings for classifier free guidance
    if negative_prompt_embeds is None:
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the unconditional input to the same sequence length as the prompt embeddings.
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=prompt_embeds.shape[1],
            truncation=True,
            return_tensors="pt",
        )

        uncond_output = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=_attention_mask_for(uncond_input),
        )
        negative_prompt_embeds = uncond_output[0]

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    uncond_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, uncond_len, -1)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    return torch.cat([negative_prompt_embeds, prompt_embeds])
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator): image = image.to(device=device, dtype=dtype) init_latent_dist = self.vae.encode(image).latent_dist init_latents = init_latent_dist.sample(generator=generator) init_latents = self.vae.config.scaling_factor * init_latents # Expand init_latents for batch_size and num_images_per_prompt init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) init_latents_orig = init_latents # add noise to latents using the timesteps noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents, init_latents_orig, noise @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, 
add_predicted_noise: Optional[bool] = False, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, that will be used as the starting point for the process. This is the image whose masked region will be inpainted. mask_image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` is 1, the denoising process will be run on the masked area for the full number of iterations specified in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. num_inference_steps (`int`, *optional*, defaults to 50): The reference number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
This parameter will be modulated by `strength`, as explained above. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. add_predicted_noise (`bool`, *optional*, defaults to True): Use predicted noise instead of random noise when constructing noisy versions of the original image in the reverse diffusion process eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image and mask if not isinstance(image, torch.FloatTensor): image = preprocess_image(image, batch_size) mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables # encode the init image into latents and scale the latents latents, init_latents_orig, noise = self.prepare_latents( image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare mask latent mask = mask_image.to(device=device, dtype=latents.dtype) mask = torch.cat([mask] * num_images_per_prompt) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # masking if add_predicted_noise: init_latents_proper = self.scheduler.add_noise( init_latents_orig, noise_pred_uncond, torch.tensor([t]) ) else: init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) latents = (init_latents_proper * mask) + (latents * (1 - mask)) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # use original latents corresponding to unmasked portions of the image latents = (init_latents_orig * mask) + (latents * (1 - mask)) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in 
has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py ================================================ # Copyright 2023 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Callable, List, Optional, Union import numpy as np import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor, ) from ..pipeline_utils import DiffusionPipeline from . 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
    """
    Deprecated helper that turns PIL image(s) or a list of tensors into one batched tensor.

    A `FutureWarning` is emitted on every call.

    Args:
        image: one of
            - a `torch.Tensor`, which is returned unchanged;
            - a single `PIL.Image.Image`, handled like a one-element list;
            - an indexable collection whose first element decides how the whole collection is handled.

    Returns:
        - for PIL input: a float32 tensor built from the images resized (lanczos) to the first image's size
          rounded down to a multiple of 8, scaled from [0, 255] to [-1, 1] and moved to channels-first layout;
        - for a collection of tensors: their concatenation along dim 0;
        - otherwise the input as given.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    # a tensor is passed through untouched
    if isinstance(image, torch.Tensor):
        return image

    images = [image] if isinstance(image, PIL.Image.Image) else image
    head = images[0]

    if isinstance(head, torch.Tensor):
        return torch.cat(images, dim=0)
    if not isinstance(head, PIL.Image.Image):
        # element type not handled here: hand the input back unchanged
        return images

    # every image is resized to the first image's size, rounded down to an integer multiple of 8
    width, height = (side - side % 8 for side in head.size)
    resampling = PIL_INTERPOLATION["lanczos"]
    frames = [np.array(pil_image.resize((width, height), resample=resampling))[None, :] for pil_image in images]

    pixels = np.concatenate(frames, axis=0).astype(np.float32) / 255.0
    pixels = pixels.transpose(0, 3, 1, 2)  # NHWC -> NCHW
    pixels = 2.0 * pixels - 1.0  # [0, 1] -> [-1, 1]
    return torch.from_numpy(pixels)
Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, num_inference_steps: int = 100, guidance_scale: float = 7.5, image_guidance_scale: float = 1.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. 
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. This pipeline requires a value of at least `1`. image_guidance_scale (`float`, *optional*, defaults to 1.5): Image guidance scale is to push the generated image towards the inital image `image`. Image guidance scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to generate images that are closely linked to the source image `image`, usually at the expense of lower image quality. This pipeline requires a value of at least `1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
Examples: ```py >>> import PIL >>> import requests >>> import torch >>> from io import BytesIO >>> from diffusers import StableDiffusionInstructPix2PixPipeline >>> def download_image(url): ... response = requests.get(url) ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" >>> image = download_image(img_url).resize((512, 512)) >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "make the mountains snowy" >>> image = pipe(prompt=prompt, image=image).images[0] ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Check inputs self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) if image is None: raise ValueError("`image` input cannot be undefined.") # 1. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0 # check if scheduler is in sigmas space scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") # 2. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 3. Preprocess image image = self.image_processor.preprocess(image) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare Image latents image_latents = self.prepare_image_latents( image, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, do_classifier_free_guidance, generator, ) height, width = image_latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Check that shapes of latents and image match the UNet channels num_channels_image = image_latents.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Expand the latents if we are doing classifier free guidance. # The latents are expanded 3 times because for pix2pix the guidance\ # is applied for both the text and the input image. latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents # concat latents, image_latents in the channel dimension scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual noise_pred = self.unet( scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the # predicted_original_sample instead of the noise_pred. So we need to compute the # predicted_original_sample here if we are using a karras style scheduler. if scheduler_is_in_sigma_space: step_index = (self.scheduler.timesteps == t).nonzero()[0].item() sigma = self.scheduler.sigmas[step_index] noise_pred = latent_model_input - sigma * noise_pred # perform guidance if do_classifier_free_guidance: noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3) noise_pred = ( noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_image) + image_guidance_scale * (noise_pred_image - noise_pred_uncond) ) # Hack: # For karras style schedulers the model does classifer free guidance using the # predicted_original_sample instead of the noise_pred. But the scheduler.step function # expects the noise_pred and computes the predicted_original_sample internally. So we # need to overwrite the noise_pred here such that the value of the computed # predicted_original_sample is correct. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offload every sub-model to CPU through accelerate's `cpu_offload`.

    `unet`, `text_encoder`, `vae` and, when present, `safety_checker` are each registered with
    `accelerate.cpu_offload`, using `cuda:{gpu_id}` as the execution device. Offloading happens on a
    submodule basis: memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.

    Args:
        gpu_id (`int`, *optional*, defaults to 0):
            Index of the CUDA device used as execution device.

    Raises:
        ImportError: If accelerate is not available or is older than `0.14.0`.
    """
    accelerate_is_usable = is_accelerate_available() and is_accelerate_version(">=", "0.14.0")
    if not accelerate_is_usable:
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    pipeline_off_cpu = self.device.type != "cpu"
    if pipeline_off_cpu:
        # Move the whole pipeline back to CPU first and release the cached CUDA memory;
        # otherwise we don't see the memory savings (but they probably exist).
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()

    for submodel in (self.unet, self.text_encoder, self.vae):
        cpu_offload(submodel, execution_device)

    safety_checker = self.safety_checker
    if safety_checker is not None:
        # The safety checker is the only component that is offloaded together with its buffers.
        cpu_offload(safety_checker, execution_device=execution_device, offload_buffers=True)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_ prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise 
TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds] prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """
    Validate the prompt-related call arguments; returns `None` when everything is consistent.

    Args:
        prompt: `str`, `list` or `None`. Exactly one of `prompt` / `prompt_embeds` must be given.
        callback_steps: must be a positive `int`.
        negative_prompt: optional negative prompt; mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: optional pre-computed prompt embeddings.
        negative_prompt_embeds: optional pre-computed negative embeddings; when passed together with
            `prompt_embeds` both must have the same `.shape`.

    Raises:
        ValueError: On the first rule that is violated.
    """
    steps_are_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None

    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Directly supplied embeddings must have matching shapes.
    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_image_latents(
    self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
):
    """
    Turn the conditioning image into latents sized for the effective batch.

    The image is moved to `device`/`dtype`. If its second dimension is already 4 it is used as latents
    directly; otherwise it is encoded with `self.vae` and the mode of the latent distribution is taken.
    The latents are then duplicated up to `batch_size * num_images_per_prompt` (deprecated behavior, a
    deprecation notice is emitted) and, for classifier free guidance, tripled along the batch dimension
    as `[image_latents, image_latents, zeros]`.

    Raises:
        ValueError: If `image` has an unsupported type, if a generator list does not match the effective
            batch size, or if the image batch cannot be evenly duplicated to the effective batch size.
    """
    accepted_types = (torch.Tensor, PIL.Image.Image, list)
    if not isinstance(image, accepted_types):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    # From here on `batch_size` is the effective batch size (prompts x images per prompt).
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        # Already latent-shaped input: skip the VAE.
        image_latents = image
    elif isinstance(generator, list):
        if len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        # Encode one sample at a time when a list of generators is supplied.
        per_sample_latents = [
            self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)
        ]
        image_latents = torch.cat(per_sample_latents, dim=0)
    else:
        image_latents = self.vae.encode(image).latent_dist.mode()

    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
        )

    if batch_size > image_latents.shape[0]:
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        repeats = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * repeats, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if do_classifier_free_guidance:
        # Third chunk is the "no image" condition: all-zero latents.
        zero_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat([image_latents, image_latents, zero_latents], dim=0)

    return image_latents
class ModelWrapper:
    """
    Thin adapter that exposes a wrapped model through an `apply_model` method.

    `apply_model` translates a conditioning value, given either as third positional argument or as
    `cond=` keyword, into the `encoder_hidden_states=` keyword of the wrapped model and returns the
    `.sample` attribute of the model output.

    Args:
        model: Callable invoked as `model(*args, encoder_hidden_states=..., **kwargs)`.
        alphas_cumprod: Stored unchanged on the instance as `alphas_cumprod`.
    """

    def __init__(self, model, alphas_cumprod):
        self.model = model
        self.alphas_cumprod = alphas_cumprod

    def apply_model(self, *args, **kwargs):
        """
        Call the wrapped model and return the `.sample` attribute of its output.

        The conditioning is taken from the third positional argument and/or the `cond` keyword; a
        non-`None` `cond` wins when both are given. `cond` is never forwarded to the wrapped model.

        Raises:
            TypeError: If no conditioning was supplied at all.
        """
        has_positional_cond = len(args) == 3
        if has_positional_cond:
            encoder_hidden_states = args[-1]
            args = args[:2]

        # Always remove `cond`, even when it is None, so it is not passed on as an unexpected keyword.
        cond = kwargs.pop("cond", None)
        if cond is not None:
            encoder_hidden_states = cond
        elif not has_positional_cond:
            # Previously this surfaced as an UnboundLocalError on `encoder_hidden_states`.
            raise TypeError(
                "apply_model() requires the conditioning either as third positional argument or as `cond`."
            )

        return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker: bool = True, ): super().__init__() logger.info( f"{self.__class__} is an experimntal pipeline and is likely to change in the future. We recommend to use" " this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines" " as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for" " production settings." 
def set_scheduler(self, scheduler_type: str):
    """
    Select the sampling function used by the pipeline.

    Imports `k_diffusion`, looks up the attribute named `scheduler_type` on its `sampling` attribute and
    stores the result as `self.sampler`. An unknown name propagates the `AttributeError` of the lookup.

    Args:
        scheduler_type (`str`): Name of an attribute of `k_diffusion.sampling`.
    """
    sampling_module = importlib.import_module("k_diffusion").sampling
    self.sampler = getattr(sampling_module, scheduler_type)
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on `image`.

    Returns:
        A tuple `(image, has_nsfw_concept)`. Without a safety checker the image is returned untouched and
        `has_nsfw_concept` is `None`; otherwise both values are whatever `self.safety_checker` returns.
    """
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images: tensors go through `postprocess`, anything else
    # through `numpy_to_pil`.
    processor = self.image_processor
    if torch.is_tensor(image):
        pil_images = processor.postprocess(image, output_type="pil")
    else:
        pil_images = processor.numpy_to_pil(image)

    checker_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    clip_input = checker_features.pixel_values.to(dtype)
    image, has_nsfw_concept = self.safety_checker(images=image, clip_input=clip_input)
    return image, has_nsfw_concept
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, use_karras_sigmas: Optional[bool] = False, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. 
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. 
prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M Karras`. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. 
Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = True if guidance_scale <= 1.0: raise ValueError("has to use guidance_scale") # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device) # 5. Prepare sigmas if use_karras_sigmas: sigma_min: float = self.k_diffusion_model.sigmas[0].item() sigma_max: float = self.k_diffusion_model.sigmas[-1].item() sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max) sigmas = sigmas.to(device) else: sigmas = self.scheduler.sigmas sigmas = sigmas.to(prompt_embeds.dtype) # 6. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) latents = latents * sigmas[0] self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) # 7. Define model function def model_fn(x, t): latent_model_input = torch.cat([x] * 2) t = torch.cat([t] * 2) noise_pred = self.k_diffusion_model(latent_model_input, t, cond=prompt_embeds) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) return noise_pred # 8. Run k-diffusion solver latents = self.sampler(model_fn, latents, sigmas) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
def preprocess(image):
    """
    Convert an image input to a torch tensor (deprecated; emits a `FutureWarning`).

    Tensors are returned unchanged. A single PIL image is wrapped in a list. A list of PIL images is
    resized to the first image's size rounded down to a multiple of 64, scaled from `[0, 255]` to
    `[-1, 1]`, and stacked channel-first. A list of tensors is concatenated along dim 0.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64

        image = [np.array(i.resize((w, h)))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


class StableDiffusionLatentUpscalePipeline(DiffusionPipeline):
    r"""
    Pipeline to upscale the resolution of Stable Diffusion output images by a factor of 2.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/main/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`EulerDiscreteScheduler`].
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: EulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        # one factor of 2 per VAE block after the first
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads the unet, text_encoder and vae to CPU using accelerate's `cpu_offload`, with
        `cuda:{gpu_id}` as the execution device. Components that are `None` are skipped.

        Raises:
            ImportError: if `accelerate` is not available.
        """
        if not is_accelerate_available():
            raise ImportError("Please install accelerate via `pip install accelerate`")

        from accelerate import cpu_offload

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in (self.unet, self.text_encoder, self.vae):
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Only read when
                `do_classifier_free_guidance` is `True`; `None` is replaced by empty strings.

        Returns:
            `(text_embeddings, text_pooler_out)`: the text encoder's last hidden state and its pooled
            output. With classifier free guidance the unconditional results are concatenated in front
            of the conditional ones along dim 0.

        Raises:
            TypeError: if `negative_prompt` is given with a different type than `prompt`.
            ValueError: if a list `negative_prompt` does not match the prompt batch size.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_length=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # tokenize again without truncation so any dropped text can be reported to the user
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_encoder_out = self.text_encoder(
            text_input_ids.to(device),
            output_hidden_states=True,
        )
        text_embeddings = text_encoder_out.hidden_states[-1]
        text_pooler_out = text_encoder_out.pooler_output

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_length=True,
                return_tensors="pt",
            )

            uncond_encoder_out = self.text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )

            uncond_embeddings = uncond_encoder_out.hidden_states[-1]
            uncond_pooler_out = uncond_encoder_out.pooler_output

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
            text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])

        return text_embeddings, text_pooler_out

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        """
        Decode latents with the VAE and return them as a float32 numpy array (deprecated).

        Emits a `FutureWarning` on every call. The decoded output is mapped with `x / 2 + 0.5`,
        clamped to `[0, 1]`, moved to CPU and permuted with `(0, 2, 3, 1)`.
        """
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def check_inputs(self, prompt, image, callback_steps):
        """
        Validate the arguments of `__call__`.

        Raises:
            ValueError: if `prompt` is neither `str` nor `list`; if `image` is not a tensor, PIL image
                or list; if a list/tensor `image` has a different batch size than `prompt`; or if
                `callback_steps` is not a positive `int`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
            )

        # verify batch size of prompt and image are same if image is a list or tensor
        if isinstance(image, (list, torch.Tensor)):
            batch_size = 1 if isinstance(prompt, str) else len(prompt)
            if isinstance(image, list):
                image_batch_size = len(image)
            else:
                # a tensor that is not 4-D is counted as a single image
                image_batch_size = image.shape[0] if image.ndim == 4 else 1
            if batch_size != image_batch_size:
                raise ValueError(
                    f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
                    " Please make sure that passed `prompt` matches the batch size of `image`."
                )

        # `None` fails the isinstance test, so it is rejected here as well
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Return initial latents of shape `(batch_size, num_channels_latents, height, width)`, scaled by
        `self.scheduler.init_noise_sigma`. Noise is drawn when `latents` is `None`; otherwise the supplied
        tensor is shape-checked (raising `ValueError` on mismatch) and moved to `device`.
        """
        shape = (batch_size, num_channels_latents, height, width)
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        num_inference_steps: int = 75,
        guidance_scale: float = 9.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image upscaling. Replaced by empty strings when
                `guidance_scale` is `0`.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be
                either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`.
                After preprocessing, an input with 3 channels (`image.shape[1] == 3`) is encoded using this
                pipeline's `vae` encoder; any other channel count is used as latents directly. Although the
                default is `None`, `check_inputs` rejects `None`.
            num_inference_steps (`int`, *optional*, defaults to 75):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 9.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is not greater than `1`).
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
                the VAE decode is skipped.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Examples:
        ```py
        >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
        >>> import torch


        >>> pipeline = StableDiffusionPipeline.from_pretrained(
        ...     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
        ... )
        >>> pipeline.to("cuda")

        >>> model_id = "stabilityai/sd-x2-latent-upscaler"
        >>> upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        >>> upscaler.to("cuda")

        >>> prompt = "a photo of an astronaut high resolution, unreal engine, ultra realistic"
        >>> generator = torch.manual_seed(33)

        >>> low_res_latents = pipeline(prompt, generator=generator, output_type="latent").images

        >>> with torch.no_grad():
        ...     image = pipeline.decode_latents(low_res_latents)
        >>> image = pipeline.numpy_to_pil(image)[0]

        >>> image.save("../images/a1.png")

        >>> upscaled_image = upscaler(
        ...     prompt=prompt,
        ...     image=low_res_latents,
        ...     num_inference_steps=20,
        ...     guidance_scale=0,
        ...     generator=generator,
        ... ).images[0]

        >>> upscaled_image.save("../images/a2.png")
        ```

        Returns:
            [`ImagePipelineOutput`] or `tuple`: an [`ImagePipelineOutput`] holding the upscaled images if
            `return_dict` is True, otherwise a one-element `tuple` whose only item is the upscaled images.
            This pipeline has no safety checker, so no nsfw flags are returned.
        """

        # 1. Check inputs
        self.check_inputs(prompt, image, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if guidance_scale == 0:
            prompt = [""] * batch_size

        # 3. Encode input prompt
        text_embeddings, text_pooler_out = self._encode_prompt(
            prompt, device, do_classifier_free_guidance, negative_prompt
        )

        # 4. Preprocess image
        image = self.image_processor.preprocess(image)
        image = image.to(dtype=text_embeddings.dtype, device=device)
        if image.shape[1] == 3:
            # encode image if not in latent-space yet
            image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor

        # 5. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # duplicate the conditioning image so it lines up with the doubled latent batch under guidance
        batch_multiplier = 2 if do_classifier_free_guidance else 1
        image = image[None, :] if image.ndim == 3 else image
        image = torch.cat([image] * batch_multiplier)

        # 6. Add noise to image (set to be 0):
        # (see below notes from the author):
        # "This step theoretically can make the model work better on out-of-distribution inputs, but mostly just seems to make it match the input less, so it's turned off by default."
        noise_level = torch.tensor([0.0], dtype=torch.float32, device=device)
        noise_level = torch.cat([noise_level] * image.shape[0])
        inv_noise_level = (noise_level**2 + 1) ** (-0.5)

        image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
        image_cond = image_cond.to(text_embeddings.dtype)

        # fixed 128-wide noise-level embedding: 64 ones followed by 64 zeros per sample
        noise_level_embed = torch.cat(
            [
                torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
                torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
            ],
            dim=1,
        )

        timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)

        # 7. Prepare latent variables
        height, width = image.shape[2:]
        num_channels_latents = self.vae.config.latent_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height * 2,  # 2x upscale
            width * 2,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 8. Check that sizes of image and latents match
        num_channels_image = image.shape[1]
        if num_channels_latents + num_channels_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                " `pipeline.unet` or your `image` input."
            )

        # 9. Denoising loop
        num_warmup_steps = 0

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                sigma = self.scheduler.sigmas[i]
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                scaled_model_input = torch.cat([scaled_model_input, image_cond], dim=1)
                # preconditioning parameter based on Karras et al. (2022) (table 1)
                timestep = torch.log(sigma) * 0.25

                noise_pred = self.unet(
                    scaled_model_input,
                    timestep,
                    encoder_hidden_states=text_embeddings,
                    timestep_cond=timestep_condition,
                ).sample

                # in original repo, the output contains a variance channel that's not used
                noise_pred = noise_pred[:, :-1]

                # apply preconditioning, based on table 1 in Karras et al. (2022)
                inv_sigma = 1 / (sigma**2 + 1)
                noise_pred = inv_sigma * latent_model_input + self.scheduler.scale_model_input(sigma, t) * noise_pred

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
# The example must import the class it instantiates; the previous text imported
# `StableDiffusionPipeline` and then used `StableDiffusionLDM3DPipeline`, which fails with NameError.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionLDM3DPipeline

        >>> pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> output = pipe(prompt)
        >>> rgb_image, depth_image = output.rgb, output.depth
        ```
"""


@dataclass
class LDM3DPipelineOutput(BaseOutput):
    """
    Output class for the LDM3D Stable Diffusion pipeline.

    Args:
        rgb (`List[PIL.Image.Image]` or `np.ndarray`)
            The generated RGB images, as a list of PIL images or a numpy array.
        depth (`List[PIL.Image.Image]` or `np.ndarray`)
            The generated depth images, as a list of PIL images or a numpy array.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
    """

    rgb: Union[List[PIL.Image.Image], np.ndarray]
    depth: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the VAE scale factor and image processor.

    Args:
        vae: Autoencoder registered as `self.vae`; its `config.block_out_channels` determines
            `self.vae_scale_factor`.
        text_encoder: Text encoder registered as `self.text_encoder`.
        tokenizer: Tokenizer registered as `self.tokenizer`.
        unet: Denoising network registered as `self.unet`.
        scheduler: Scheduler registered as `self.scheduler`.
        safety_checker: Optional safety checker; pass `None` to disable it.
        feature_extractor: Feature extractor feeding the safety checker. Required whenever `safety_checker`
            is not `None`.
        requires_safety_checker: Stored in the pipeline config. When `True` and `safety_checker` is `None`,
            a warning is logged.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string: without the prefix the message showed the literal
        # text "{self.__class__}" instead of the pipeline class.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    # One factor of two per VAE block transition: this is the pixel-to-latent size ratio used by
    # `prepare_latents` and by the default height/width in `__call__`.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when `do_classifier_free_guidance` is `True` and
            `negative_prompt_embeds` is not given; an empty string per prompt is used when it is `None`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        The prompt embeddings repeated `num_images_per_prompt` times per prompt. With classifier free guidance
        the negative embeddings are concatenated in front of them, so the leading dimension doubles.

    Raises:
        TypeError: If `prompt` and `negative_prompt` are both given but have different types.
        ValueError: If `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    # The batch size comes from the prompt when given, otherwise from the precomputed embeddings.
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation, only to detect and report text that was cut off.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # An attention mask is passed only when the text encoder config asks for one.
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # NOTE(review): this yields a single negative prompt; when `prompt` is None and `prompt_embeds`
            # has a batch size above 1 the later `view` looks like it would not match - confirm with callers.
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker on a decoded image batch.

    Args:
        image: Decoded output, either a tensor or a numpy batch; it is converted to PIL only to build the
            checker's CLIP input.
        device: Device the feature extractor output is moved to.
        dtype: Dtype the extracted pixel values are cast to before the check.

    Returns:
        A tuple `(image, has_nsfw_concept)`. Without a safety checker the image is returned untouched and
        the flags are `None`; otherwise both values come from the safety checker.
    """
    checker = self.safety_checker
    if checker is None:
        return image, None

    processor = self.image_processor
    if torch.is_tensor(image):
        pil_outputs = processor.postprocess(image, output_type="pil")
    else:
        pil_outputs = processor.numpy_to_pil(image)

    # Only the first (RGB) element of the converted output is fed to the feature extractor.
    rgb_pil = pil_outputs[0]
    clip_features = self.feature_extractor(rgb_pil, return_tensors="pt").to(device)
    checked_image, nsfw_flags = checker(images=image, clip_input=clip_features.pixel_values.to(dtype))
    return checked_image, nsfw_flags
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the arguments of a pipeline call.

    Raises:
        ValueError: If `height` or `width` is not divisible by 8; if `callback_steps` is not a positive
            integer; if `prompt` and `prompt_embeds` are both given or both missing; if `prompt` is neither
            a `str` nor a `list`; if `negative_prompt` and `negative_prompt_embeds` are both given; or if
            `prompt_embeds` and `negative_prompt_embeds` are both given with different shapes.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_invalid = callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0
    if steps_invalid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Shapes are only comparable when both embeddings were passed in directly.
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """
    Build the initial latents for the denoising loop.

    Args:
        batch_size: Effective batch size (first dimension of the latent shape).
        num_channels_latents: Number of latent channels.
        height: Image height in pixels; divided by `self.vae_scale_factor` to get the latent height.
        width: Image width in pixels; divided by `self.vae_scale_factor` to get the latent width.
        dtype: Dtype used when sampling fresh noise.
        device: Device for the sampled noise, or the device given latents are moved to.
        generator: A generator or list of generators; a list must have exactly `batch_size` entries.
        latents: Optional pre-generated latents. When given, no noise is sampled.

    Returns:
        The latents multiplied by the scheduler's `init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        initial_latents = latents.to(device)
    else:
        initial_latents = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    return initial_latents * self.scheduler.init_noise_sigma
@torch.no_grad()
# NOTE(review): `replace_example_docstring` presumably fills the bare `Examples:` line of the docstring with
# EXAMPLE_DOC_STRING - keep that placeholder line intact when editing the docstring.
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 49,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 49):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"` the
            VAE decoding and the safety checker are skipped and the latents are post-processed directly.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`LDM3DPipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is forwarded to the unet call as `cross_attention_kwargs`.

    Examples:

    Returns:
        [`LDM3DPipelineOutput`] or `tuple`:
        [`LDM3DPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first
        element is the pair `(rgb, depth)` of generated outputs, and the second element is a list of `bool`s
        denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content,
        according to the `safety_checker`, or `None` when no safety check was run.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    # Timesteps beyond `num_inference_steps * scheduler.order` count as warmup and do not advance the bar.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                # The first half of the batch is the unconditional prediction, matching the order in which
                # `_encode_prompt` concatenates the embeddings.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Images flagged by the safety checker are not denormalized.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return ((rgb, depth), has_nsfw_concept)

    return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
depth), has_nsfw_concept) return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py ================================================ # Copyright 2023 TIME Authors and The HuggingFace Team. All rights reserved." # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name AUGS_CONST = ["A photo of ", "An image of ", "A picture of "] EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionModelEditingPipeline >>> model_ckpt = "CompVis/stable-diffusion-v1-4" >>> pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt) >>> pipe = pipe.to("cuda") >>> source_prompt = "A pack of roses" >>> destination_prompt = "A pack of blue roses" >>> pipe.edit_model(source_prompt, destination_prompt) >>> prompt = "A field of roses" >>> image = pipe(prompt).images[0] ``` """ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models". This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.). Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: SchedulerMixin,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPFeatureExtractor,
    requires_safety_checker: bool = True,
    with_to_k: bool = True,
    with_augs: list = AUGS_CONST,
):
    r"""
    Register the pipeline components and collect the cross-attention projections that `edit_model` rewrites.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor:
            Components registered on the pipeline through `register_modules`.
        requires_safety_checker (`bool`, defaults to `True`):
            Stored in the pipeline config. When `True` and `safety_checker` is `None`, a warning is logged.
        with_to_k (`bool`, defaults to `True`):
            Whether the key projections (`to_k`) are edited together with the value projections (`to_v`).
        with_augs (`list`, defaults to `AUGS_CONST`):
            Prefixes prepended to the prompts in `edit_model`. Pass `[]` for no augmentations. The list is
            copied, so later changes to `self.with_augs` never touch the caller's list or the module default.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if isinstance(scheduler, PNDMScheduler):
        logger.error("PNDMScheduler for this pipeline is currently not supported.")

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The message is an f-string so that the class name is actually interpolated.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)

    self.with_to_k = with_to_k
    # Copy: the default is the shared module-level list, which must not be mutated through one pipeline.
    self.with_augs = list(with_augs)

    # get cross-attention layers
    ca_layers = []

    def append_ca(net_):
        # Depth-first search for modules whose class is named "CrossAttention".
        if net_.__class__.__name__ == "CrossAttention":
            ca_layers.append(net_)
        elif hasattr(net_, "children"):
            for net__ in net_.children():
                append_ca(net__)

    # recursively find all cross-attention layers in the down, up and mid parts of the unet
    for child_name, child in self.unet.named_children():
        if any(part in child_name for part in ("down", "up", "mid")):
            append_ca(child)

    # get projection matrices
    # Only layers whose value projection takes 768 input features are kept
    # (presumably the CLIP text-embedding width -- TODO confirm for other text encoders).
    self.ca_clip_layers = [layer for layer in ca_layers if layer.to_v.in_features == 768]
    self.projection_matrices = [layer.to_v for layer in self.ca_clip_layers]
    # Pristine copies used by `edit_model(restart_params=True)` to undo earlier edits.
    self.og_matrices = [copy.deepcopy(layer.to_v) for layer in self.ca_clip_layers]
    if self.with_to_k:
        # Key projections are appended after all value projections; `edit_model` relies on this layout.
        self.projection_matrices = self.projection_matrices + [layer.to_k for layer in self.ca_clip_layers]
        self.og_matrices = self.og_matrices + [copy.deepcopy(layer.to_k) for layer in self.ca_clip_layers]
""" self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    r"""
    Validate the arguments of `__call__`; returns `None` when everything is consistent.

    Raises:
        ValueError: If `height` or `width` is not divisible by 8, if `callback_steps` is not a positive
            `int`, if `prompt` and `prompt_embeds` are both given or both missing, if `prompt` is neither
            `str` nor `list`, if `negative_prompt` and `negative_prompt_embeds` are both given, or if
            `prompt_embeds` and `negative_prompt_embeds` are both given with different shapes.
    """
    # Spatial size: any non-zero remainder fails the check.
    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None`, non-integers and non-positive integers are all rejected with the same message.
    steps_ok = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be supplied.
    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # At most one of `negative_prompt` / `negative_prompt_embeds` may be supplied.
    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    # Pre-computed positive and negative embeddings have to line up.
    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
@torch.no_grad()
def edit_model(
    self,
    source_prompt: str,
    destination_prompt: str,
    lamb: float = 0.1,
    restart_params: bool = True,
):
    r"""
    Apply model editing via closed-form solution (see Eq. 5 in the TIME paper https://arxiv.org/abs/2303.08084)

    Args:
        source_prompt (`str`):
            The source prompt containing the concept to be edited.
        destination_prompt (`str`):
            The destination prompt. Must contain all words from source_prompt with additional ones to specify the
            target edit.
        lamb (`float`, *optional*, defaults to 0.1):
            The lambda parameter specifying the regularization intensity. Smaller values increase the editing power.
        restart_params (`bool`, *optional*, defaults to True):
            Restart the model parameters to their pre-trained version before editing. This is done to avoid edit
            compounding. When it is False, edits accumulate.

    Raises:
        IndexError: If a token of a source sentence cannot be found, in order, in the matching destination
            sentence.
    """
    # restart LDM parameters
    if restart_params:
        num_ca_clip_layers = len(self.ca_clip_layers)
        for idx_, ca_layer in enumerate(self.ca_clip_layers):
            ca_layer.to_v = copy.deepcopy(self.og_matrices[idx_])
            self.projection_matrices[idx_] = ca_layer.to_v
            if self.with_to_k:
                # Key projections are stored after all value projections (see `__init__`).
                ca_layer.to_k = copy.deepcopy(self.og_matrices[num_ca_clip_layers + idx_])
                self.projection_matrices[num_ca_clip_layers + idx_] = ca_layer.to_k

    # set up sentences
    old_texts = [source_prompt]
    new_texts = [destination_prompt]
    # add augmentations; a leading capital "A" is lower-cased before a prefix is put in front of the prompt
    base = old_texts[0] if old_texts[0][0:1] != "A" else "a" + old_texts[0][1:]
    for aug in self.with_augs:
        old_texts.append(aug + base)
    base = new_texts[0] if new_texts[0][0:1] != "A" else "a" + new_texts[0][1:]
    for aug in self.with_augs:
        new_texts.append(aug + base)

    # The embeddings below are padded to this length, so the index lists must use it too
    # (it was hard-coded as 77 before).
    max_length = self.tokenizer.model_max_length

    # prepare input k* and v*
    old_embs, new_embs = [], []
    for old_text, new_text in zip(old_texts, new_texts):
        text_input = self.tokenizer(
            [old_text, new_text],
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        old_emb, new_emb = text_embeddings
        old_embs.append(old_emb)
        new_embs.append(new_emb)

    # identify corresponding destinations for each token in old_emb
    # "an" is mapped onto the token of "a" so that an article change does not break the alignment.
    a_token = self.tokenizer.encode("a ")[1]
    idxs_replaces = []
    for old_text, new_text in zip(old_texts, new_texts):
        tokens_a = [a_token if self.tokenizer.decode(t) == "an" else t for t in self.tokenizer(old_text).input_ids]
        tokens_b = [a_token if self.tokenizer.decode(t) == "an" else t for t in self.tokenizer(new_text).input_ids]
        idxs_replace = []
        j = 0
        for curr_token in tokens_a:
            # Next position at or after `j` in the destination that holds this source token.
            try:
                j = tokens_b.index(curr_token, j)
            except ValueError:
                raise IndexError(
                    f"Token {curr_token} of {old_text!r} was not found (in order) in {new_text!r}; the"
                    " destination prompt must contain all words of the source prompt."
                ) from None
            idxs_replace.append(j)
            j += 1
        # Remaining destination positions follow one-to-one; the list is then padded with the last index.
        idxs_replace.extend(range(j, max_length))
        idxs_replace.extend([max_length - 1] * (max_length - len(idxs_replace)))
        idxs_replaces.append(idxs_replace)

    # prepare batch: for each pair of sentences, old context and new values
    # (no explicit detach / no_grad needed: the whole method runs under `torch.no_grad()`)
    contexts, valuess = [], []
    for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
        contexts.append(old_emb)
        valuess.append([layer(new_emb[idxs_replace]) for layer in self.projection_matrices])

    # edit the model
    for layer_num, projection in enumerate(self.projection_matrices):
        weight = projection.weight
        # Half-precision weights are solved in float32 and cast back afterwards; the identity matrix was
        # previously created in the default dtype, which only matched float32 weights.
        low_precision = weight.dtype in (torch.float16, torch.bfloat16)
        compute_dtype = torch.float32 if low_precision else weight.dtype

        # mat1 = \lambda W + \sum{v k^T}
        mat1 = lamb * weight.to(compute_dtype)
        # mat2 = \lambda I + \sum{k k^T}
        mat2 = lamb * torch.eye(weight.shape[1], device=weight.device, dtype=compute_dtype)

        # aggregate sums for mat1, mat2
        for context, values in zip(contexts, valuess):
            keys = context.to(compute_dtype)
            vals = values[layer_num].to(compute_dtype)
            # \sum_i v_i k_i^T == V^T K and \sum_i k_i k_i^T == K^T K: same sums as the per-token outer
            # products, without materialising a (tokens, out, in) tensor.
            mat1 += vals.T @ keys
            mat2 += keys.T @ keys

        # update projection matrix
        projection.weight = torch.nn.Parameter((mat1 @ torch.inverse(mat2)).to(weight.dtype))
"pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * 
image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py ================================================ # Copyright 2023 MultiDiffusion Authors and The HuggingFace Team. All rights reserved." # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . 
class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
    Generation".

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    To generate panorama-like images, be sure to pass the `width` parameter accordingly when using the pipeline. Our
    recommendation for the `width` value is 2048. This is the default value of the `width` parameter for this pipeline.

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. The original work
            on MultiDiffusion used the [`DDIMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DDIMScheduler,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # The first literal must be an f-string, otherwise `{self.__class__}` is emitted verbatim.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
             prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
        r"""
        Enumerate the sliding-window crops (in latent coordinates) that tile the panorama.

        These are the mappings F_i of Eq. 7 in the MultiDiffusion paper (https://arxiv.org/abs/2302.08113).

        Args:
            panorama_height (`int`): Height of the panorama in pixels.
            panorama_width (`int`): Width of the panorama in pixels.
            window_size (`int`, *optional*, defaults to 64): Side length of each crop, in latent units.
            stride (`int`, *optional*, defaults to 8): Step between neighbouring crops, in latent units.

        Returns:
            `List[Tuple[int, int, int, int]]`: One `(h_start, h_end, w_start, w_end)` tuple per crop, ordered
            row by row.
        """
        # Convert pixel sizes to latent sizes with the same factor `prepare_latents` uses, so the crop grid
        # always matches the latent tensor it is applied to.
        latent_height = int(panorama_height // self.vae_scale_factor)
        latent_width = int(panorama_width // self.vae_scale_factor)

        # if panorama's height/width < window_size, num_blocks of height/width should return 1
        num_blocks_height = int((latent_height - window_size) // stride + 1) if latent_height > window_size else 1
        num_blocks_width = int((latent_width - window_size) // stride + 1) if latent_width > window_size else 1

        views = []
        for block_index in range(num_blocks_height * num_blocks_width):
            row, col = divmod(block_index, num_blocks_width)
            h_start = int(row * stride)
            w_start = int(col * stride)
            views.append((h_start, h_start + window_size, w_start, w_start + window_size))
        return views

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 2048,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        view_batch_size: int = 1,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 2048):
                The width in pixels of the generated image. The width is kept to a high number because the
                pipeline is supposed to be used for generating panorama-like images.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            view_batch_size (`int`, *optional*, defaults to 1):
                The batch size used to denoise the split views. On GPUs with enough memory, a higher view batch size
                can speed up the generation at the cost of increased VRAM usage.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # With guidance, `_encode_prompt` returns [neg_0..neg_{B-1}, pos_0..pos_{B-1}], whereas the latents below
        # are duplicated per sample ([x_0, x_0, x_1, x_1, ...]). Re-order the embeddings to
        # [neg_0, pos_0, neg_1, pos_1, ...] so every latent copy is paired with its own embedding; for an
        # effective batch of 1 this is the same ordering as before.
        if do_classifier_free_guidance:
            negative_embeds, positive_embeds = prompt_embeds.chunk(2)
            prompt_embeds_per_view = torch.stack([negative_embeds, positive_embeds], dim=1).flatten(0, 1)
        else:
            prompt_embeds_per_view = prompt_embeds

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Define panorama grid and initialize views for synthesis.
        # prepare batch grid
        views = self.get_views(height, width)
        views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
        # Each view batch gets its OWN snapshot of the scheduler state. `[snapshot] * n` would alias a single
        # dict, and because `__dict__.update` shares the snapshot's mutable values with the live scheduler,
        # in-place state changes made while stepping one view batch would leak into the others.
        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__) for _ in views_batch]
        count = torch.zeros_like(latents)
        value = torch.zeros_like(latents)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        # Each denoising step also includes refinement of the latents with respect to the
        # views.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                count.zero_()
                value.zero_()

                # generate views
                # Here, we iterate through different spatial crops of the latents and denoise them. These
                # denoised (latent) crops are then averaged to produce the final latent
                # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
                # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
                # Batch views denoise
                for j, batch_view in enumerate(views_batch):
                    vb_size = len(batch_view)
                    # get the latents corresponding to the current view coordinates
                    latents_for_view = torch.cat(
                        [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
                    )

                    # rematch block's scheduler status
                    self.scheduler.__dict__.update(views_scheduler_status[j])

                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = (
                        latents_for_view.repeat_interleave(2, dim=0)
                        if do_classifier_free_guidance
                        else latents_for_view
                    )
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    # repeat the (already interleaved) prompt embeddings once per view in this batch
                    prompt_embeds_input = torch.cat([prompt_embeds_per_view] * vb_size)

                    # predict the noise residual
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds_input,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample

                    # perform guidance: even rows are unconditional, odd rows are text-conditioned
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents_denoised_batch = self.scheduler.step(
                        noise_pred, t, latents_for_view, **extra_step_kwargs
                    ).prev_sample

                    # save views scheduler status after sample
                    views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__)

                    # extract value from batch
                    for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
                        latents_denoised_batch.chunk(vb_size), batch_view
                    ):
                        value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                        count[:, :, h_start:h_end, w_start:w_end] += 1

                # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
                latents = torch.where(count > 0, value / count, value)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
import inspect from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import DDPMParallelScheduler >>> from diffusers import StableDiffusionParadigmsPipeline >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler") >>> pipe = StableDiffusionParadigmsPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> ngpu, batch_per_device = torch.cuda.device_count(), 5 >>> pipe.wrapped_unet = torch.nn.DataParallel(pipe.unet, device_ids=[d for d in range(ngpu)]) >>> prompt = "a photo of an astronaut riding a horse on mars" >>> image = pipe(prompt, parallel=ngpu * batch_per_device, num_inference_steps=1000).images[0] ``` """ class StableDiffusionParadigmsPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): r""" Parallelized version of StableDiffusionPipeline, based on the paper https://arxiv.org/abs/2305.16317 This pipeline parallelizes the denoising steps to generate a single image faster (more akin to model parallelism). Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # attribute to wrap the unet with torch.nn.DataParallel when running multiple denoising steps on multiple GPUs self.wrapped_unet = self.unet # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. 
""" self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def _cumsum(self, input, dim, debug=False): if debug: # cumsum_cuda_kernel does not have a deterministic implementation # so perform cumsum on cpu for debugging purposes return torch.cumsum(input.cpu().float(), dim=dim).to(input.device) else: return torch.cumsum(input, dim=dim) @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, parallel: int = 10, tolerance: float = 0.1, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] 
= None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, debug: bool = False, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. parallel (`int`, *optional*, defaults to 10): The batch size to use when doing parallel sampling. More parallelism may lead to faster inference but requires higher memory usage and also can require more total FLOPs. tolerance (`float`, *optional*, defaults to 0.1): The error tolerance for determining when to slide the batch window forward for parallel sampling. Lower tolerance usually leads to less/no degradation. Higher tolerance is faster but can risk degradation of sample quality. The tolerance is specified as a ratio of the scheduler's noise magnitude. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). debug (`bool`, *optional*, defaults to `False`): Whether or not to run in debug mode. In debug mode, torch.cumsum is evaluated using the CPU. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) extra_step_kwargs.pop("generator", None) # # 7. Denoising loop scheduler = self.scheduler parallel = min(parallel, len(scheduler.timesteps)) begin_idx = 0 end_idx = parallel latents_time_evolution_buffer = torch.stack([latents] * (len(scheduler.timesteps) + 1)) # We must make sure the noise of stochastic schedulers such as DDPM is sampled only once per timestep. # Sampling inside the parallel denoising loop will mess this up, so we pre-sample the noise vectors outside the denoising loop. 
noise_array = torch.zeros_like(latents_time_evolution_buffer) for j in range(len(scheduler.timesteps)): base_noise = randn_tensor( shape=latents.shape, generator=generator, device=latents.device, dtype=prompt_embeds.dtype ) noise = (self.scheduler._get_variance(scheduler.timesteps[j]) ** 0.5) * base_noise noise_array[j] = noise.clone() # We specify the error tolerance as a ratio of the scheduler's noise magnitude. We similarly compute the error tolerance # outside of the denoising loop to avoid recomputing it at every step. # We will be dividing the norm of the noise, so we store its inverse here to avoid a division at every step. inverse_variance_norm = 1.0 / torch.tensor( [scheduler._get_variance(scheduler.timesteps[j]) for j in range(len(scheduler.timesteps))] + [0] ).to(noise_array.device) latent_dim = noise_array[0, 0].numel() inverse_variance_norm = inverse_variance_norm[:, None] / latent_dim scaled_tolerance = tolerance**2 with self.progress_bar(total=num_inference_steps) as progress_bar: steps = 0 while begin_idx < len(scheduler.timesteps): # these have shape (parallel_dim, 2*batch_size, ...) 
# parallel_len is at most parallel, but could be less if we are at the end of the timesteps # we are processing batch window of timesteps spanning [begin_idx, end_idx) parallel_len = end_idx - begin_idx block_prompt_embeds = torch.stack([prompt_embeds] * parallel_len) block_latents = latents_time_evolution_buffer[begin_idx:end_idx] block_t = scheduler.timesteps[begin_idx:end_idx, None].repeat(1, batch_size * num_images_per_prompt) t_vec = block_t if do_classifier_free_guidance: t_vec = t_vec.repeat(1, 2) # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([block_latents] * 2, dim=1) if do_classifier_free_guidance else block_latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t_vec) # if parallel_len is small, no need to use multiple GPUs net = self.wrapped_unet if parallel_len > 3 else self.unet # predict the noise residual, shape is now [parallel_len * 2 * batch_size * num_images_per_prompt, ...] model_output = net( latent_model_input.flatten(0, 1), t_vec.flatten(0, 1), encoder_hidden_states=block_prompt_embeds.flatten(0, 1), cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] per_latent_shape = model_output.shape[1:] if do_classifier_free_guidance: model_output = model_output.reshape( parallel_len, 2, batch_size * num_images_per_prompt, *per_latent_shape ) noise_pred_uncond, noise_pred_text = model_output[:, 0], model_output[:, 1] model_output = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) model_output = model_output.reshape( parallel_len * batch_size * num_images_per_prompt, *per_latent_shape ) block_latents_denoise = scheduler.batch_step_no_noise( model_output=model_output, timesteps=block_t.flatten(0, 1), sample=block_latents.flatten(0, 1), **extra_step_kwargs, ).reshape(block_latents.shape) # back to shape (parallel_dim, batch_size, ...) 
# now we want to add the pre-sampled noise # parallel sampling algorithm requires computing the cumulative drift from the beginning # of the window, so we need to compute cumulative sum of the deltas and the pre-sampled noises. delta = block_latents_denoise - block_latents cumulative_delta = self._cumsum(delta, dim=0, debug=debug) cumulative_noise = self._cumsum(noise_array[begin_idx:end_idx], dim=0, debug=debug) # if we are using an ODE-like scheduler (like DDIM), we don't want to add noise if scheduler._is_ode_scheduler: cumulative_noise = 0 block_latents_new = ( latents_time_evolution_buffer[begin_idx][None,] + cumulative_delta + cumulative_noise ) cur_error = torch.linalg.norm( (block_latents_new - latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1]).reshape( parallel_len, batch_size * num_images_per_prompt, -1 ), dim=-1, ).pow(2) error_ratio = cur_error * inverse_variance_norm[begin_idx + 1 : end_idx + 1] # find the first index of the vector error_ratio that is greater than error tolerance # we can shift the window for the next iteration up to this index error_ratio = torch.nn.functional.pad( error_ratio, (0, 0, 0, 1), value=1e9 ) # handle the case when everything is below ratio, by padding the end of parallel_len dimension any_error_at_time = torch.max(error_ratio > scaled_tolerance, dim=1).values.int() ind = torch.argmax(any_error_at_time).item() # compute the new begin and end idxs for the window new_begin_idx = begin_idx + min(1 + ind, parallel) new_end_idx = min(new_begin_idx + parallel, len(scheduler.timesteps)) # store the computed latents for the current window in the global buffer latents_time_evolution_buffer[begin_idx + 1 : end_idx + 1] = block_latents_new # initialize the new sliding window latents with the end of the current window, # should be better than random initialization latents_time_evolution_buffer[end_idx : new_end_idx + 1] = latents_time_evolution_buffer[end_idx][ None, ] steps += 1 progress_bar.update(new_begin_idx - begin_idx) 
if callback is not None and steps % callback_steps == 0: callback(begin_idx, block_t[begin_idx], latents_time_evolution_buffer[begin_idx]) begin_idx = new_begin_idx end_idx = new_end_idx latents = latents_time_evolution_buffer[-1] if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py ================================================ # Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from transformers import ( BlipForConditionalGeneration, BlipProcessor, CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, ) from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ...utils import ( PIL_INTERPOLATION, BaseOutput, deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): """ Output class for Stable Diffusion pipelines. Args: latents (`torch.FloatTensor`) inverted latents tensor images (`List[PIL.Image.Image]` or `np.ndarray`) List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. """ latents: torch.FloatTensor images: Union[List[PIL.Image.Image], np.ndarray] EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline >>> def download(embedding_url, local_filepath): ... r = requests.get(embedding_url) ... with open(local_filepath, "wb") as f: ... 
f.write(r.content) >>> model_ckpt = "CompVis/stable-diffusion-v1-4" >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.to("cuda") >>> prompt = "a high resolution painting of a cat in the style of van gough" >>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt" >>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt" >>> for url in [source_emb_url, target_emb_url]: ... download(url, url.split("/")[-1]) >>> src_embeds = torch.load(source_emb_url.split("/")[-1]) >>> target_embeds = torch.load(target_emb_url.split("/")[-1]) >>> images = pipeline( ... prompt, ... source_embeds=src_embeds, ... target_embeds=target_embeds, ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... ).images >>> images[0].save("edited_image_dog.png") ``` """ EXAMPLE_INVERT_DOC_STRING = """ Examples: ```py >>> import torch >>> from transformers import BlipForConditionalGeneration, BlipProcessor >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline >>> import requests >>> from PIL import Image >>> captioner_id = "Salesforce/blip-image-captioning-base" >>> processor = BlipProcessor.from_pretrained(captioner_id) >>> model = BlipForConditionalGeneration.from_pretrained( ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True ... ) >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( ... sd_model_ckpt, ... caption_generator=model, ... caption_processor=processor, ... torch_dtype=torch.float16, ... safety_checker=None, ... 
) >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) >>> pipeline.enable_model_cpu_offload() >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) >>> # generate caption >>> caption = pipeline.generate_caption(raw_image) >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" >>> inv_latents = pipeline.invert(caption, image=raw_image).latents >>> # we need to generate source and target embeds >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] >>> source_embeds = pipeline.get_embeds(source_prompts) >>> target_embeds = pipeline.get_embeds(target_prompts) >>> # the latents can then be used to edit a real image >>> # when using Stable Diffusion 2 or other models that use v-prediction >>> # set `cross_attention_guidance_amount` to 0.01 or less to avoid input latent gradient explosion >>> image = pipeline( ... caption, ... source_embeds=source_embeds, ... target_embeds=target_embeds, ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... generator=generator, ... latents=inv_latents, ... negative_prompt=caption, ... ).images[0] >>> image.save("edited_image.png") ``` """ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): warnings.warn( "The preprocess method is deprecated and will be removed in a future version. 
class Pix2PixZeroL2Loss:
    """Running squared-error loss that grows with every `compute_loss` call.

    `loss` starts at `0.0` and is never reset by this class; create a fresh instance to start over.
    """

    def __init__(self):
        self.loss = 0.0

    def compute_loss(self, predictions, targets):
        """Add the squared error between `predictions` and `targets` to `self.loss`.

        The squared difference is summed over dimensions 1 and 2 and then averaged over dimension 0.
        """
        squared_error = (predictions - targets) ** 2
        per_item_total = squared_error.sum((1, 2))
        self.loss += per_item_total.mean(0)
In Pix2Pix Zero, it happens during computations in the cross-attention blocks.""" def __init__(self, is_pix2pix_zero=False): self.is_pix2pix_zero = is_pix2pix_zero if self.is_pix2pix_zero: self.reference_cross_attn_map = {} def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, timestep=None, loss=None, ): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) if self.is_pix2pix_zero and timestep is not None: # new bookkeeping to save the attention weights. if loss is None: self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu() # compute loss elif loss is not None: prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item()) loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device)) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): r""" Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. requires_safety_checker (bool): Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the pipeline publicly. 
""" _optional_components = [ "safety_checker", "feature_extractor", "caption_generator", "caption_processor", "inverse_scheduler", ] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], feature_extractor: CLIPImageProcessor, safety_checker: StableDiffusionSafetyChecker, inverse_scheduler: DDIMInverseScheduler, caption_generator: BlipForConditionalGeneration, caption_processor: BlipProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, caption_processor=caption_processor, caption_generator=caption_generator, inverse_scheduler=inverse_scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. 
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed.

    When the UNet carries an Accelerate `_hf_hook`, the first submodule hook that exposes a non-`None`
    `execution_device` decides the result. In every other case `self.device` is returned.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            # Missing hooks and missing attributes both collapse to None and are skipped.
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """Decode `latents` with the VAE into a channels-last float32 numpy array clamped to [0, 1].

    Deprecated: emits a `FutureWarning` on every call.
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )
    vae = self.vae
    # Undo the scaling factor before handing the latents to the decoder.
    unscaled = 1 / vae.config.scaling_factor * latents
    decoded = vae.decode(unscaled, return_dict=False)[0]
    # Shift from [-1, 1] to [0, 1] and clip anything outside that range.
    normalized = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    channels_last = normalized.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
def check_inputs(
    self,
    prompt,
    source_embeds,
    target_embeds,
    callback_steps,
    prompt_embeds=None,
):
    """Validate the arguments of the generation call before any heavy work starts.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        source_embeds: source concept embeddings; must not be `None`.
        target_embeds: target concept embeddings; must not be `None`.
        callback_steps: positive `int`.
        prompt_embeds: pre-computed prompt embeddings or `None`.

    Raises:
        ValueError: if any of the constraints above is violated.
    """
    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Both embeddings feed the edit direction, so a single missing one is already an error.
    if source_embeds is None or target_embeds is None:
        raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
torch.FloatTensor: num_prompts = len(prompt) embeds = [] for i in range(0, num_prompts, batch_size): prompt_slice = prompt[i : i + batch_size] input_ids = self.tokenizer( prompt_slice, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ).input_ids input_ids = input_ids.to(self.text_encoder.device) embeds.append(self.text_encoder(input_ids)[0]) return torch.cat(embeds, dim=0).mean(0)[None] def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) if image.shape[1] == 4: latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if isinstance(generator, list): latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] latents = torch.cat(latents, dim=0) else: latents = self.vae.encode(image).latent_dist.sample(generator) latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: # expand image_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
    """Convert `model_output` into a noise (epsilon) estimate for the inverse scheduler.

    The conversion depends on `self.inverse_scheduler.config.prediction_type`:
    `"epsilon"` returns `model_output` unchanged, while `"sample"` and `"v_prediction"` combine it
    with `sample` using the cumulative alpha product at `timestep`.

    Raises:
        ValueError: if the prediction type is none of the three supported values.
    """
    scheduler = self.inverse_scheduler
    prediction_kind = scheduler.config.prediction_type
    alpha_bar = scheduler.alphas_cumprod[timestep]
    one_minus_alpha_bar = 1 - alpha_bar

    if prediction_kind == "epsilon":
        return model_output

    if prediction_kind == "sample":
        signal_part = alpha_bar ** (0.5) * model_output
        return (sample - signal_part) / one_minus_alpha_bar ** (0.5)

    if prediction_kind == "v_prediction":
        return (alpha_bar**0.5) * model_output + (one_minus_alpha_bar**0.5) * sample

    raise ValueError(
        f"prediction_type given as {prediction_kind} must be one of `epsilon`, `sample`, or `v_prediction`"
    )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    source_embeds: torch.Tensor = None,
    target_embeds: torch.Tensor = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    cross_attention_guidance_amount: float = 0.1,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    The method runs two denoising passes from the same initial latents: the first records cross-attention maps
    for the unedited prompt embeddings, the second denoises with the edited embeddings while taking one SGD
    step per timestep on the model input, driven by a `Pix2PixZeroL2Loss` that is handed to the UNet.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        source_embeds (`torch.Tensor`):
            Source concept embeddings. Generation of the embeddings as per the [original
            paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
        target_embeds (`torch.Tensor`):
            Target concept embeddings. Generation of the embeddings as per the [original
            paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        cross_attention_guidance_amount (`float`, defaults to 0.1):
            Amount of guidance needed from the reference cross-attention maps. It is used as the learning rate
            of the per-timestep SGD step on the model input.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the VAE decode and the safety checker are skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.FloatTensor)`. Only the first denoising loop invokes it.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            Accepted and defaulted to `{}`, but not forwarded: every UNet call in this method builds its own
            `cross_attention_kwargs` dictionary.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
        element is a list of `bool`s denoting whether the corresponding generated image likely represents
        "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
    """
    # 0. Define the spatial resolutions.
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        source_embeds,
        target_embeds,
        callback_steps,
        prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    # NOTE(review): normalised here but never read again in this method.
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Generate the inverted noise from the input image or any other image
    # generated from the input prompt.
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    # kept so the second denoising loop can restart from the same starting point
    latents_init = latents.clone()

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Rejig the UNet so that we can obtain the cross-attention maps and
    # use them for guiding the subsequent image generation.
    self.unet = prepare_unet(self.unet)

    # 8. Denoising loop where we obtain the cross-attention maps.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual; the timestep is handed to the attention processors
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs={"timestep": t},
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 9. Compute the edit directions.
    edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device)

    # 10. Edit the prompt embeddings as per the edit directions discovered.
    # NOTE(review): only row 1 is shifted — presumably the text-conditioned row when guidance stacks
    # [unconditional, conditional] for a single prompt; confirm for other batch layouts.
    prompt_embeds_edit = prompt_embeds.clone()
    prompt_embeds_edit[1:2] += edit_direction

    # 11. Second denoising loop to generate the edited image.
    latents = latents_init
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # we want to learn the latent such that it steers the generation
            # process towards the edited direction, so make the initial
            # noise learnable
            x_in = latent_model_input.detach().clone()
            x_in.requires_grad = True

            # optimizer: a fresh SGD instance per timestep, taking a single step on `x_in`
            opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount)

            # gradients are needed here even though the method is wrapped in `torch.no_grad()`
            with torch.enable_grad():
                # initialize loss
                loss = Pix2PixZeroL2Loss()

                # predict the noise residual; the loss object is filled in through `cross_attention_kwargs`
                noise_pred = self.unet(
                    x_in,
                    t,
                    encoder_hidden_states=prompt_embeds_edit.detach(),
                    cross_attention_kwargs={"timestep": t, "loss": loss},
                ).sample

                loss.loss.backward(retain_graph=False)
                opt.step()

            # recompute the noise with the updated input
            noise_pred = self.unet(
                x_in.detach(),
                t,
                encoder_hidden_states=prompt_embeds_edit,
                cross_attention_kwargs={"timestep": None},
            ).sample

            # keep the first chunk of the optimised model input as the latents for this step
            latents = x_in.detach().chunk(2)[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # only the progress bar is advanced here; `callback` is not invoked in this loop
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # images flagged by the safety checker are not denormalized
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
def invert(
    self,
    prompt: Optional[str] = None,
    image: Union[
        torch.FloatTensor,
        PIL.Image.Image,
        np.ndarray,
        List[torch.FloatTensor],
        List[PIL.Image.Image],
        List[np.ndarray],
    ] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    cross_attention_guidance_amount: float = 0.1,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    lambda_auto_corr: float = 20.0,
    lambda_kl: float = 20.0,
    num_reg_steps: int = 5,
    num_auto_corr_rolls: int = 5,
):
    r"""
    Function used to generate inverted latents given a prompt and image.

    The image is encoded to latents and stepped through `self.inverse_scheduler`. At every step the noise
    prediction is nudged by gradient steps on `auto_corr_loss` and `kl_divergence` before the scheduler
    consumes it.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
            `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
            image latents as `image`, if passing latents directly, it will not be encoded again.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 1):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Accepted for interface compatibility. NOTE(review): the argument is overwritten by the result of
            `prepare_image_latents` before it is ever read.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        cross_attention_guidance_amount (`float`, defaults to 0.1):
            Amount of guidance needed from the reference cross-attention maps. NOTE(review): not referenced
            anywhere in this method's body.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a
            [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`]
            instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            Accepted and defaulted to `{}`, but not forwarded: the UNet call builds its own dictionary.
        lambda_auto_corr (`float`, *optional*, defaults to 20.0):
            Step size applied to the gradient of the auto-correlation loss. The auto-correlation update is
            skipped when the value is not positive.
        lambda_kl (`float`, *optional*, defaults to 20.0):
            Step size applied to the gradient of the Kullback–Leibler divergence term. The update is skipped
            when the value is not positive.
        num_reg_steps (`int`, *optional*, defaults to 5):
            Number of regularization passes per inversion step.
        num_auto_corr_rolls (`int`, *optional*, defaults to 5):
            Number of auto-correlation gradient updates per regularization pass; each gradient is divided by
            this number.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or
        `tuple`:
        [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the inverted
        latents tensor and the second is the corresponding decoded image.
    """
    # 1. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    # NOTE(review): normalised here but never read again in this method.
    if cross_attention_kwargs is None:
        cross_attention_kwargs = {}

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Preprocess image
    image = self.image_processor.preprocess(image)

    # 3. Prepare latent variables
    latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)

    # 4. Encode input prompt
    num_images_per_prompt = 1
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
    )

    # 5. Prepare timesteps
    self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.inverse_scheduler.timesteps

    # 6. Rejig the UNet so that we can obtain the cross-attention maps and
    # use them for guiding the subsequent image generation.
    self.unet = prepare_unet(self.unet)

    # 7. Denoising loop where we obtain the cross-attention maps.
    # The last timestep is left out, hence `num_inference_steps - 1` progress-bar ticks.
    num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
    with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
        for i, t in enumerate(timesteps[:-1]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs={"timestep": t},
            ).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # regularization of the noise prediction
            # (gradients are re-enabled locally because the method runs under `torch.no_grad()`)
            with torch.enable_grad():
                for _ in range(num_reg_steps):
                    if lambda_auto_corr > 0:
                        for _ in range(num_auto_corr_rolls):
                            # fresh leaf tensor so `.grad` holds only this roll's gradient
                            var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

                            # Derive epsilon from model output before regularizing to IID standard normal
                            var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                            l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
                            l_ac.backward()

                            grad = var.grad.detach() / num_auto_corr_rolls
                            noise_pred = noise_pred - lambda_auto_corr * grad

                    if lambda_kl > 0:
                        var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

                        # Derive epsilon from model output before regularizing to IID standard normal
                        var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                        l_kld = self.kl_divergence(var_epsilon)
                        l_kld.backward()

                        grad = var.grad.detach()
                        noise_pred = noise_pred - lambda_kl * grad

                    noise_pred = noise_pred.detach()

            # compute the next sample with the inverse scheduler
            latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
            ):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    inverted_latents = latents.detach().clone()

    # 8. Post-processing
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (inverted_latents, image)

    return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionSAGPipeline >>> pipe = StableDiffusionSAGPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ... 
# processes and stores attention probabilities
class CrossAttnStoreProcessor:
    """Attention processor that keeps the attention probabilities of its most recent call.

    The probabilities are exposed as `attention_probs`, which is `None` until the processor has been called.
    """

    def __init__(self):
        self.attention_probs = None

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run attention through `attn`'s projections and remember the attention probabilities.

        Args:
            attn: Attention module supplying the projections and helper methods used below.
            hidden_states: Query-side input with three dimensions.
            encoder_hidden_states: Key/value-side input; `hidden_states` is used when it is `None`.
            attention_mask: Optional mask, passed through `attn.prepare_attention_mask`.

        Returns:
            The attended hidden states after `attn.to_out[0]` and `attn.to_out[1]`.
        """
        num_samples, num_tokens, _ = hidden_states.shape
        mask = attn.prepare_attention_mask(attention_mask, num_tokens, num_samples)

        queries = attn.to_q(hidden_states)

        # self-attention falls back to the hidden states themselves
        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        keys = attn.to_k(context)
        values = attn.to_v(context)

        queries, keys, values = (attn.head_to_batch_dim(tensor) for tensor in (queries, keys, values))

        self.attention_probs = attn.get_attention_scores(queries, keys, mask)
        attended = attn.batch_to_head_dim(torch.bmm(self.attention_probs, values))

        # to_out[0] is the linear projection, to_out[1] the dropout
        projected = attn.to_out[0](attended)
        return attn.to_out[1](projected)
Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Only consulted when guidance is on and
            `negative_prompt_embeds` is not given; an empty string per prompt is used when it is `None`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, they are generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from
            `negative_prompt`.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        The prompt embeddings repeated `num_images_per_prompt` times; with classifier free guidance the
        negative embeddings are concatenated in front of them along the batch dimension.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    def mask_for(tokenized):
        # an attention mask is only passed on when the text encoder config asks for one
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            return tokenized.attention_mask.to(device)
        return None

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        was_truncated = untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        )
        if was_truncated:
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        encoder_output = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=mask_for(text_inputs),
        )
        prompt_embeds = encoder_output[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if not do_classifier_free_guidance:
        return prompt_embeds

    # get unconditional embeddings for classifier free guidance
    if negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=prompt_embeds.shape[1],
            truncation=True,
            return_tensors="pt",
        )
        uncond_output = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=mask_for(uncond_input),
        )
        negative_prompt_embeds = uncond_output[0]

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    uncond_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, uncond_len, -1)

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    return torch.cat([negative_prompt_embeds, prompt_embeds])
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the user-facing arguments of the pipeline call.

    Checks run in a fixed order and the first failing one raises.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        height: must be divisible by 8.
        width: must be divisible by 8.
        callback_steps: must be an `int` greater than zero.
        negative_prompt: mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds: optional pre-computed embeddings.
        negative_prompt_embeds: optional pre-computed negative embeddings; when
            given together with `prompt_embeds` the two shapes must match.

    Raises:
        ValueError: if any of the constraints above is violated.
    """
    if not (height % 8 == 0 and width % 8 == 0):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_are_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_embeds = prompt_embeds is not None
    if has_prompt and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    has_negative_embeds = negative_prompt_embeds is not None
    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
None, callback_steps: Optional[int] = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. sag_scale (`float`, *optional*, defaults to 0.75): SAG scale as defined in [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance] (https://arxiv.org/abs/2210.00939). `sag_scale` is defined as `s_s` of equation (24) of SAG paper: https://arxiv.org/pdf/2210.00939.pdf. Typically chosen between [0, 1.0] for better quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. 
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # and `sag_scale` is` `s` of equation (16) # of the self-attentnion guidance paper: https://arxiv.org/pdf/2210.00939.pdf # `sag_scale = 0` means no self-attention guidance do_self_attention_guidance = sag_scale > 0.0 # 3. 
Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop store_processor = CrossAttnStoreProcessor() self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order map_size = None def get_map_size(module, input, output): nonlocal map_size map_size = output[0].shape[-2:] with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # perform self-attention guidance with the stored self-attentnion map if do_self_attention_guidance: # classifier-free guidance produces two 
chunks of attention map # and we only use unconditional one according to equation (25) # in https://arxiv.org/pdf/2210.00939.pdf if do_classifier_free_guidance: # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) # get the stored attention maps uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) # self-attention-based degrading of latents degraded_latents = self.sag_masking( pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t) ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred, t) # get the stored attention maps cond_attn = store_processor.attention_probs # self-attention-based degrading of latents degraded_latents = self.sag_masking( pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t) ) # forward and give guidance degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample noise_pred += sag_scale * (noise_pred - degraded_pred) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = 
def sag_masking(self, original_latents, attn_map, map_size, t, eps):
    """Blur the latents where the self-attention mask is set, then re-noise them.

    Same masking process as in the SAG paper: https://arxiv.org/pdf/2210.00939.pdf

    Args:
        original_latents: 4-D tensor unpacked as (batch, channels, height, width).
        attn_map: 3-D tensor that is reshaped to (batch, heads, rows, cols),
            where `heads` is `self.unet.config.attention_head_dim` (its last
            entry when that value is a list).
        map_size: two-element size; the thresholded mask is reshaped to
            (batch, map_size[0], map_size[1]).
        t: timestep(s) forwarded to `self.scheduler.add_noise`.
        eps: noise forwarded to `self.scheduler.add_noise`.

    Returns:
        The result of `self.scheduler.add_noise` applied to the blended latents.
    """
    # Unpacking also enforces the expected ranks of both tensors.
    _, rows, cols = attn_map.shape
    batch, channels, height, width = original_latents.shape

    head_setting = self.unet.config.attention_head_dim
    heads = head_setting[-1] if isinstance(head_setting, list) else head_setting

    # Build the mask: average over dim 1 (heads), sum over the next dim, threshold at 1.0.
    per_head = attn_map.reshape(batch, heads, rows, cols)
    head_average = per_head.mean(1, keepdim=False)
    salient = head_average.sum(1, keepdim=False) > 1.0
    mask = salient.reshape(batch, map_size[0], map_size[1])
    mask = mask.unsqueeze(1).repeat(1, channels, 1, 1)
    mask = mask.type(per_head.dtype)
    mask = F.interpolate(mask, (height, width))

    # Blur everything, then keep the blurred values only where the mask is set.
    blurred = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0)
    blended = blurred * mask + original_latents * (1 - mask)

    # Noise it again to match the noise level of the current timestep.
    return self.scheduler.add_noise(blended, noise=eps, timesteps=t)
def pred_x0(self, sample, model_output, timestep):
    """Estimate the original sample x_0 from the current sample and the model output.

    The formula depends on `self.scheduler.config.prediction_type`:

    * ``"epsilon"``: ``(sample - sqrt(1 - a) * model_output) / sqrt(a)``
    * ``"sample"``: ``model_output`` is returned unchanged
    * ``"v_prediction"``: ``sqrt(a) * sample - sqrt(1 - a) * model_output``

    where ``a = self.scheduler.alphas_cumprod[timestep]``.

    Args:
        sample: current noisy sample.
        model_output: output of the denoising model for `sample`.
        timestep: index into `self.scheduler.alphas_cumprod`.

    Returns:
        The predicted original sample.

    Raises:
        ValueError: if the scheduler's prediction type is not one of the three above.
    """
    alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
    beta_prod_t = 1 - alpha_prod_t
    prediction_type = self.scheduler.config.prediction_type

    if prediction_type == "epsilon":
        return (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        # The original also reassigned `model_output` here ("predict V"), but the
        # value was never read; that dead store has been dropped.
        return (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output

    raise ValueError(
        f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`,"
        " or `v_prediction`"
    )
def preprocess(image):
    """Convert PIL image(s) or tensors into one batched tensor (deprecated).

    Emits a `FutureWarning` on every call.

    Args:
        image: a `torch.Tensor` (returned unchanged), a single `PIL.Image.Image`,
            or a sequence whose first element decides the handling:
            PIL images are resized so both sides are multiples of 64 (using the
            first image's size), stacked, scaled to [-1, 1] and moved from
            axis order (0, 1, 2, 3) to (0, 3, 1, 2); tensors are concatenated
            along dim 0.

    Returns:
        A `torch.Tensor` for the supported inputs; any other input is returned as is.
    """
    warnings.warn(
        "The preprocess method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor.preprocess instead",
        FutureWarning,
    )

    if isinstance(image, torch.Tensor):
        return image

    # A lone PIL image is handled as a one-element batch.
    items = [image] if isinstance(image, PIL.Image.Image) else image
    first = items[0]

    if isinstance(first, PIL.Image.Image):
        # Round each side down to an integer multiple of 64.
        width, height = (side - side % 64 for side in first.size)
        frames = [np.array(picture.resize((width, height)))[None, :] for picture in items]
        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(first, torch.Tensor):
        return torch.cat(items, dim=0)

    return items
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    low_res_scheduler: DDPMScheduler,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: Optional[Any] = None,
    feature_extractor: Optional[CLIPImageProcessor] = None,
    watermarker: Optional[Any] = None,
    max_noise_level: int = 350,
):
    """Register the pipeline components and normalise the VAE scaling factor.

    If `vae` has a `config` whose `scaling_factor` is missing or different from
    0.08333, a deprecation is issued through `deprecate` and the config is
    updated to 0.08333 via `vae.register_to_config`.

    Args:
        vae: autoencoder; its `config.block_out_channels` determines
            `self.vae_scale_factor`.
        text_encoder: text encoder registered on the pipeline.
        tokenizer: tokenizer registered on the pipeline.
        unet: denoising network registered on the pipeline.
        low_res_scheduler: scheduler registered as `low_res_scheduler`.
        scheduler: scheduler registered as `scheduler`.
        safety_checker: optional component registered as `safety_checker`.
        feature_extractor: optional component registered as `feature_extractor`.
        watermarker: optional component registered as `watermarker`.
        max_noise_level: stored in the pipeline config.
    """
    super().__init__()

    # Check whether the vae config carries `scaling_factor` == 0.08333; if not,
    # deprecate and overwrite it with 0.08333.
    if hasattr(vae, "config"):
        # Read with a default so a config that lacks the attribute still reaches the
        # deprecation path; the previous f-string raised AttributeError in that case.
        current_scaling_factor = getattr(vae.config, "scaling_factor", None)
        is_vae_scaling_factor_set_to_0_08333 = (
            hasattr(vae.config, "scaling_factor") and current_scaling_factor == 0.08333
        )
        if not is_vae_scaling_factor_set_to_0_08333:
            deprecation_message = (
                "The configuration file of the vae does not contain `scaling_factor` or it is set to"
                f" {current_scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
                " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to"
                " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to"
                " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging"
                " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file"
            )
            deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
            vae.register_to_config(scaling_factor=0.08333)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        low_res_scheduler=low_res_scheduler,
        scheduler=scheduler,
        safety_checker=safety_checker,
        watermarker=watermarker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
    self.register_to_config(max_noise_level=max_noise_level)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
    r"""
    Returns the device on which the pipeline's models will be executed. After calling
    `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
    hooks.

    The first module yielded by `self.unet.modules()` whose `_hf_hook.execution_device` is set decides the
    result; `self.device` is used when the unet has no `_hf_hook` or no such module exists.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    return self.device
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = 
self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self,
    prompt,
    image,
    noise_level,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    r"""
    Validate the arguments of a pipeline call before any model is run.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Text prompt(s). Exactly one of `prompt` and `prompt_embeds` must be given.
        image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray` or `list`):
            Image input. For a list, tensor or array its leading dimension must equal the prompt batch size.
        noise_level (`int`):
            Noise level; must not exceed `self.config.max_noise_level`.
        callback_steps (`int`):
            Callback frequency; must be a positive integer.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s). Mutually exclusive with `negative_prompt_embeds`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed prompt embeddings. Mutually exclusive with `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed negative embeddings. Must have the same shape as `prompt_embeds` when both are given.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if (
        not isinstance(image, torch.Tensor)
        and not isinstance(image, PIL.Image.Image)
        and not isinstance(image, np.ndarray)
        and not isinstance(image, list)
    ):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
        )

    # verify batch size of prompt and image are same if image is a list or tensor or numpy array
    if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # `prompt` is None here, so the checks above guarantee `prompt_embeds` is set;
            # the previous `len(prompt)` raised TypeError on this path.
            batch_size = prompt_embeds.shape[0]

        if isinstance(image, list):
            image_batch_size = len(image)
        else:
            image_batch_size = image.shape[0]

        if batch_size != image_batch_size:
            raise ValueError(
                f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
                " Please make sure that passed `prompt` matches the batch size of `image`."
            )

    # check noise level
    if noise_level > self.config.max_noise_level:
        raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be upscaled. * num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: ```py >>> import requests >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableDiffusionUpscalePipeline >>> import torch >>> # load model and scheduler >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( ... 
model_id, revision="fp16", torch_dtype=torch.float16 ... ) >>> pipeline = pipeline.to("cuda") >>> # let's download an image >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" >>> response = requests.get(url) >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") >>> low_res_img = low_res_img.resize((128, 128)) >>> prompt = "a white cat" >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] >>> upscaled_image.save("upsampled_cat.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs self.check_inputs( prompt, image, noise_level, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) if image is None: raise ValueError("`image` input cannot be undefined.") # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. 
Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image image = self.image_processor.preprocess(image) image = image.to(dtype=prompt_embeds.dtype, device=device) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Add noise to image noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 image = torch.cat([image] * batch_multiplier * num_images_per_prompt) noise_level = torch.cat([noise_level] * image.shape[0]) # 6. Prepare latent variables height, width = image.shape[2:] num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] if num_channels_latents + num_channels_image != self.unet.config.in_channels: raise ValueError( f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = torch.cat([latent_model_input, image], dim=1) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, class_labels=noise_level, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 10. 
Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype) else: latents = latents.float() # post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # 11. Apply watermark if output_type == "pil" and self.watermarker is not None: image = self.watermarker.apply_watermark(image) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableUnCLIPPipeline >>> pipe = StableUnCLIPPipeline.from_pretrained( ... "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16 ... ) # TODO update model path >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" >>> images = pipe(prompt).images >>> images[0].save("astronaut_horse.png") ``` """ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): """ Pipeline for text-to-image generation using stable unCLIP. This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: prior_tokenizer ([`CLIPTokenizer`]): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). prior_text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. prior_scheduler ([`KarrasDiffusionSchedulers`]): Scheduler used in the prior denoising process. image_normalizer ([`StableUnCLIPImageNormalizer`]): Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image embeddings after the noise has been applied. image_noising_scheduler ([`KarrasDiffusionSchedulers`]): Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). text_encoder ([`CLIPTextModel`]): Frozen text-encoder. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`KarrasDiffusionSchedulers`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
""" # prior components prior_tokenizer: CLIPTokenizer prior_text_encoder: CLIPTextModelWithProjection prior: PriorTransformer prior_scheduler: KarrasDiffusionSchedulers # image noising components image_normalizer: StableUnCLIPImageNormalizer image_noising_scheduler: KarrasDiffusionSchedulers # regular denoising components tokenizer: CLIPTokenizer text_encoder: CLIPTextModel unet: UNet2DConditionModel scheduler: KarrasDiffusionSchedulers vae: AutoencoderKL def __init__( self, # prior components prior_tokenizer: CLIPTokenizer, prior_text_encoder: CLIPTextModelWithProjection, prior: PriorTransformer, prior_scheduler: KarrasDiffusionSchedulers, # image noising components image_normalizer: StableUnCLIPImageNormalizer, image_noising_scheduler: KarrasDiffusionSchedulers, # regular denoising components tokenizer: CLIPTokenizer, text_encoder: CLIPTextModelWithProjection, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, # vae vae: AutoencoderKL, ): super().__init__() self.register_modules( prior_tokenizer=prior_tokenizer, prior_text_encoder=prior_text_encoder, prior=prior, prior_scheduler=prior_scheduler, image_normalizer=image_normalizer, image_noising_scheduler=image_noising_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, vae=vae, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
""" self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") # TODO: self.prior.post_process_latents and self.image_noiser.{scale,unscale} are not covered by the offload hooks, so they fails if added to the list models = [ self.prior_text_encoder, self.text_encoder, self.unet, self.vae, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder def _encode_prior_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None, ): if text_model_output is None: batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings text_inputs = self.prior_tokenizer( prompt, padding="max_length", max_length=self.prior_tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.prior_tokenizer.batch_decode( untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length] prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device)) prompt_embeds = prior_text_encoder_output.text_embeds prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] prompt_embeds, prior_text_encoder_hidden_states = 
text_model_output[0], text_model_output[1] text_mask = text_attention_mask prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size uncond_input = self.prior_tokenizer( uncond_tokens, padding="max_length", max_length=self.prior_tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder( uncond_input.input_ids.to(device) ) negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds uncond_prior_text_encoder_hidden_states = ( negative_prompt_embeds_prior_text_encoder_output.last_hidden_state ) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_prior_text_encoder_hidden_states.shape[1] uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat( 1, num_images_per_prompt, 1 ) uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prior_text_encoder_hidden_states = torch.cat( [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states] ) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, prior_text_encoder_hidden_states, text_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = 
prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Decode latents with the VAE and return the result as a float32 NumPy array.

    Deprecated: a `FutureWarning` is emitted on every call.

    Args:
        latents: tensor handed to `self.vae.decode` after being divided by
            `self.vae.config.scaling_factor`.

    Returns:
        NumPy array with values clamped to `[0, 1]` and axes reordered with
        `permute(0, 2, 3, 1)`.
        # assumes the decoder output is (batch, channels, height, width) — TODO confirm
    """
    warnings.warn(
        "The decode_latents method is deprecated and will be removed in a future version. Please"
        " use VaeImageProcessor instead",
        FutureWarning,
    )

    # Undo the latent scaling before decoding.
    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents, return_dict=False)[0]

    # Map the decoder output from [-1, 1] to [0, 1].
    decoded = decoded / 2 + 0.5
    decoded = decoded.clamp(0, 1)

    # float32 is used unconditionally: the cast is cheap and NumPy cannot represent bfloat16.
    channels_last = decoded.cpu().permute(0, 2, 3, 1)
    return channels_last.float().numpy()
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    noise_level,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """
    Validate the arguments of a pipeline call; returns `None` when everything is consistent.

    Args:
        prompt: `str`, `list` or `None`. Exactly one of `prompt` / `prompt_embeds` must be given.
        height: output height; must be divisible by 8.
        width: output width; must be divisible by 8.
        callback_steps: positive `int`; `None` is rejected.
        noise_level: must satisfy
            `0 <= noise_level < self.image_noising_scheduler.config.num_train_timesteps`.
        negative_prompt: optional; mutually exclusive with `negative_prompt_embeds` and, when
            `prompt` is given, must have the same type as `prompt`.
        prompt_embeds: optional pre-computed embeddings (an object exposing `.shape`).
        negative_prompt_embeds: optional; when passed together with `prompt_embeds` the two
            shapes must be equal.

    Raises:
        ValueError: for any violated constraint other than the type mismatch below.
        TypeError: when `prompt` and `negative_prompt` are both given but differ in type.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` fails the isinstance check, so it is rejected together with non-ints and values <= 0.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
        )

    if prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )

    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        # Both were supplied, so the message must ask for only one of them
        # (the previous text wrongly complained about both being undefined).
        raise ValueError(
            "Provide either `negative_prompt` or `negative_prompt_embeds`. Please make sure to define only one of"
            " the two."
        )

    if prompt is not None and negative_prompt is not None:
        if type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    max_noise_level = self.image_noising_scheduler.config.num_train_timesteps
    if noise_level < 0 or noise_level >= max_noise_level:
        raise ValueError(f"`noise_level` must be between 0 and {max_noise_level - 1}, inclusive.")
# NOTE: `replace_example_docstring` substitutes EXAMPLE_DOC_STRING into this method's docstring, so the
# docstring below must keep a line consisting only of `Examples:`.
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    # regular denoising process args
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 20,
    guidance_scale: float = 10.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    noise_level: int = 0,
    # prior args
    prior_num_inference_steps: int = 25,
    prior_guidance_scale: float = 4.0,
    prior_latents: Optional[torch.FloatTensor] = None,
):
    """
    Function invoked when calling the pipeline for generation.

    Generation runs in two stages: the prior denoises an image embedding from the text prompt, then the
    unet denoises image latents conditioned on the text embeddings and on that (optionally noised) image
    embedding.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 20):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 10.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`
            the VAE decode is skipped and the latents are handed to the image processor as they are.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents:
            torch.FloatTensor)`. It is invoked from both the prior loop (with the prior latents) and the
            image denoising loop (with the image latents); the step index restarts at 0 for the second loop.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, when present, is also used as the text encoder LoRA scale.
        noise_level (`int`, *optional*, defaults to `0`):
            The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
            the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details.
        prior_num_inference_steps (`int`, *optional*, defaults to 25):
            The number of denoising steps in the prior denoising process. More denoising steps usually lead to a
            higher quality image at the expense of slower inference.
        prior_guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale for the prior denoising process as defined in [Classifier-Free Diffusion
            Guidance](https://arxiv.org/abs/2207.12598). `prior_guidance_scale` is defined as `w` of equation 2.
            of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
            `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
            linked to the text `prompt`, usually at the expense of lower image quality.
        prior_latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            embedding generation in the prior denoising process. Can be used to tweak the same generation with
            different prompts. If not provided, a latents tensor will be generated by sampling using the
            supplied random `generator`.

    Examples:

    Returns:
        [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput`] if `return_dict` is
        True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        callback_steps=callback_steps,
        noise_level=noise_level,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # from here on `batch_size` counts generated images, not prompts
    batch_size = batch_size * num_images_per_prompt

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    prior_do_classifier_free_guidance = prior_guidance_scale > 1.0

    # 3. Encode input prompt
    # NOTE(review): only `prompt` is forwarded to the prior encoder; `prompt_embeds` is not. Confirm
    # whether calling the pipeline with `prompt=None` and only `prompt_embeds` is actually supported.
    prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=prior_do_classifier_free_guidance,
    )

    # 4. Prepare prior timesteps
    self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
    prior_timesteps_tensor = self.prior_scheduler.timesteps

    # 5. Prepare prior latent variables
    embedding_dim = self.prior.config.embedding_dim
    prior_latents = self.prepare_latents(
        (batch_size, embedding_dim),
        prior_prompt_embeds.dtype,
        device,
        generator,
        prior_latents,
        self.prior_scheduler,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta)

    # 7. Prior denoising loop
    for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents
        latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t)

        predicted_image_embedding = self.prior(
            latent_model_input,
            timestep=t,
            proj_embedding=prior_prompt_embeds,
            encoder_hidden_states=prior_text_encoder_hidden_states,
            attention_mask=prior_text_mask,
        ).predicted_image_embedding

        if prior_do_classifier_free_guidance:
            # the doubled batch is split back into its unconditional and text-conditioned halves
            predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
            predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                predicted_image_embedding_text - predicted_image_embedding_uncond
            )

        prior_latents = self.prior_scheduler.step(
            predicted_image_embedding,
            timestep=t,
            sample=prior_latents,
            **prior_extra_step_kwargs,
            return_dict=False,
        )[0]

        if callback is not None and i % callback_steps == 0:
            callback(i, t, prior_latents)

    prior_latents = self.prior.post_process_latents(prior_latents)

    image_embeds = prior_latents

    # done prior

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 8. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 9. Prepare image embeddings
    image_embeds = self.noise_image_embeddings(
        image_embeds=image_embeds,
        noise_level=noise_level,
        generator=generator,
    )

    if do_classifier_free_guidance:
        # despite its name this is the all-zero *image* embedding used for the unconditional branch;
        # the text embeddings were already consumed by `_encode_prompt` above
        negative_prompt_embeds = torch.zeros_like(image_embeds)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        image_embeds = torch.cat([negative_prompt_embeds, image_embeds])

    # 10. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 11. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    latents = self.prepare_latents(
        shape=shape,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
        latents=latents,
        scheduler=self.scheduler,
    )

    # 12. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 13. Denoising loop
    for i, t in enumerate(self.progress_bar(timesteps)):
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            class_labels=image_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    # "latent" output skips the VAE; the latents go to the image processor unchanged
    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
    else:
        image = latents

    image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image,)

    return ImagePipelineOutput(images=image)
# See the License for the specific language governing permissions and # limitations under the License. import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Union import PIL import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.utils.import_utils import is_accelerate_available from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableUnCLIPImg2ImgPipeline >>> pipe = StableUnCLIPImg2ImgPipeline.from_pretrained( ... "fusing/stable-unclip-2-1-l-img2img", torch_dtype=torch.float16 ... ) # TODO update model path >>> pipe = pipe.to("cuda") >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) >>> prompt = "A fantasy landscape, trending on artstation" >>> images = pipe(prompt, init_image).images >>> images[0].save("fantasy_landscape.png") ``` """ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): """ Pipeline for text-guided image to image generation using stable unCLIP. 
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: feature_extractor ([`CLIPImageProcessor`]): Feature extractor for image pre-processing before being encoded. image_encoder ([`CLIPVisionModelWithProjection`]): CLIP vision model for encoding images. image_normalizer ([`StableUnCLIPImageNormalizer`]): Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image embeddings after the noise has been applied. image_noising_scheduler ([`KarrasDiffusionSchedulers`]): Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined by `noise_level` in `StableUnCLIPPipeline.__call__`. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). text_encoder ([`CLIPTextModel`]): Frozen text-encoder. unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`KarrasDiffusionSchedulers`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
def __init__(
    self,
    # image encoding components
    feature_extractor: CLIPImageProcessor,
    image_encoder: CLIPVisionModelWithProjection,
    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer,
    image_noising_scheduler: KarrasDiffusionSchedulers,
    # regular denoising components
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    # vae
    vae: AutoencoderKL,
):
    """
    Register every pipeline component and set up the VAE image processor.

    All nine arguments are forwarded to `register_modules` under their own names. Afterwards
    `vae_scale_factor` is derived from the number of entries in `vae.config.block_out_channels`
    and used to build `image_processor`.
    """
    super().__init__()

    # Keyword order mirrors the constructor signature.
    components = {
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
        "image_normalizer": image_normalizer,
        "image_noising_scheduler": image_noising_scheduler,
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "unet": unet,
        "scheduler": scheduler,
        "vae": vae,
    }
    self.register_modules(**components)

    # One factor of two per VAE block after the first.
    vae_block_count = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (vae_block_count - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Apply accelerate's `cpu_offload` to the image encoder, text encoder, unet and vae.

    Each of those components that is not `None` is passed to `cpu_offload` together with the
    device `cuda:{gpu_id}`. `image_normalizer` is deliberately left out (see the TODO below).

    Args:
        gpu_id (`int`, *optional*, defaults to 0): index of the CUDA device used for execution.

    Raises:
        ImportError: if accelerate is not available.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    from accelerate import cpu_offload

    execution_device = torch.device(f"cuda:{gpu_id}")

    # TODO: self.image_normalizer.{scale,unscale} are not covered by the offload hooks, so they fail if added here
    offload_candidates = (self.image_encoder, self.text_encoder, self.unet, self.vae)
    for component in offload_candidates:
        if component is None:
            continue
        cpu_offload(component, execution_device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
         prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when `do_classifier_free_guidance` is `False`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        The text embeddings, repeated `num_images_per_prompt` times per prompt. With
        `do_classifier_free_guidance` the unconditional embeddings are concatenated in front of them
        along the first dimension, so the first half of the batch is unconditional.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given with different types.
        ValueError: if `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    # number of prompts, taken from `prompt` when given and from `prompt_embeds` otherwise
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # tokenized a second time without truncation, only to detect and report what was cut off
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        # the attention mask is only passed when the text encoder config asks for it
        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            # NOTE(review): a single string yields one unconditional row, while the `view` further down
            # expects `batch_size` rows. This looks consistent only when batch_size == 1 — confirm for
            # the case of `prompt=None` with batched `prompt_embeds`.
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # pad the unconditional tokens to the sequence length of the conditional embeddings
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes image_embeds = torch.cat([negative_prompt_embeds, image_embeds]) return image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, height, width, callback_steps, noise_level, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, image_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two." ) if prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and `negative_prompt_embeds` undefined." 
) if prompt is not None and negative_prompt is not None: if type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps: raise ValueError( f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1}, inclusive." ) if image is not None and image_embeds is not None: raise ValueError( "Provide either `image` or `image_embeds`. Please make sure to define only one of the two." ) if image is None and image_embeds is None: raise ValueError( "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." ) if image is not None: if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" f" {type(image)}" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_unclip.StableUnCLIPPipeline.noise_image_embeddings def noise_image_embeddings( self, image_embeds: torch.Tensor, noise_level: int, noise: Optional[torch.FloatTensor] = None, generator: Optional[torch.Generator] = None, ): """ Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher `noise_level` increases the variance in the final un-noised images. The noise is applied in two ways 1. A noise schedule is applied directly to the embeddings 2. A vector of sinusoidal time embeddings are appended to the output. In both cases, the amount of noise is controlled by the same `noise_level`. The embeddings are normalized before the noise is applied and un-normalized after the noise is applied. """ if noise is None: noise = randn_tensor( image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype ) noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device) self.image_normalizer.to(image_embeds.device) image_embeds = self.image_normalizer.scale(image_embeds) image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise) image_embeds = self.image_normalizer.unscale(image_embeds) noise_level = get_timestep_embedding( timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0 ) # `get_timestep_embeddings` does not contain any weights and will always return f32 tensors, # but we might actually be running in fp16. so we need to cast here. 
# there might be better ways to encapsulate this. noise_level = noise_level.to(image_embeds.dtype) image_embeds = torch.cat((image_embeds, noise_level), 1) return image_embeds @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, image: Union[torch.FloatTensor, PIL.Image.Image] = None, prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 20, guidance_scale: float = 10, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, noise_level: int = 0, image_embeds: Optional[torch.FloatTensor] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be used or prompt is initialized to `""`. image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the latents in the denoising process such as in the standard stable diffusion text guided image variation process. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 20): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 10.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). noise_level (`int`, *optional*, defaults to `0`): The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in the final un-noised images. See `StableUnCLIPPipeline.noise_image_embeddings` for details. image_embeds (`torch.FloatTensor`, *optional*): Pre-generated CLIP embeddings to condition the unet on. Note that these are not latents to be used in the denoising process. If you want to provide pre-generated latents, pass them to `__call__` as `latents`. Examples: Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor if prompt is None and prompt_embeds is None: prompt = len(image) * [""] if isinstance(image, list) else "" # 1. Check inputs. Raise error if not correct self.check_inputs( prompt=prompt, image=image, height=height, width=width, callback_steps=callback_steps, noise_level=noise_level, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, image_embeds=image_embeds, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] batch_size = batch_size * num_images_per_prompt device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. 
Encoder input image noise_level = torch.tensor([noise_level], device=device) image_embeds = self._encode_image( image=image, device=device, batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, noise_level=noise_level, generator=generator, image_embeds=image_embeds, ) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size=batch_size, num_channels_latents=num_channels_latents, height=height, width=width, dtype=prompt_embeds.dtype, device=device, generator=generator, latents=latents, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=image_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. 
Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/safety_checker.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np import torch import torch.nn as nn from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel from ...utils import logging logger = logging.get_logger(__name__) def cosine_distance(image_embeds, text_embeds): normalized_image_embeds = nn.functional.normalize(image_embeds) normalized_text_embeds = nn.functional.normalize(text_embeds) return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig _no_split_modules = ["CLIPEncoderLayer"] def __init__(self, config: CLIPConfig): super().__init__(config) self.vision_model = CLIPVisionModel(config.vision_config) self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) @torch.no_grad() def forward(self, clip_input, images): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() result = [] batch_size = image_embeds.shape[0] for i in range(batch_size): result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} # increase this value to create a stronger `nfsw` filter # at the cost of increasing the possibility of filtering benign images adjustment = 0.0 for concept_idx in range(len(special_cos_dist[0])): concept_cos = 
special_cos_dist[i][concept_idx] concept_threshold = self.special_care_embeds_weights[concept_idx].item() result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) if result_img["special_scores"][concept_idx] > 0: result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) adjustment = 0.01 for concept_idx in range(len(cos_dist[0])): concept_cos = cos_dist[i][concept_idx] concept_threshold = self.concept_embeds_weights[concept_idx].item() result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) if result_img["concept_scores"][concept_idx] > 0: result_img["bad_concepts"].append(concept_idx) result.append(result_img) has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): if has_nsfw_concept: if torch.is_tensor(images) or torch.is_tensor(images[0]): images[idx] = torch.zeros_like(images[idx]) # black image else: images[idx] = np.zeros(images[idx].shape) # black image if any(has_nsfw_concepts): logger.warning( "Potential NSFW content was detected in one or more images. A black image will be returned instead." " Try again with a different prompt and/or seed." 
) return images, has_nsfw_concepts @torch.no_grad() def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) cos_dist = cosine_distance(image_embeds, self.concept_embeds) # increase this value to create a stronger `nsfw` filter # at the cost of increasing the possibility of filtering benign images adjustment = 0.0 special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment # special_scores = special_scores.round(decimals=3) special_care = torch.any(special_scores > 0, dim=1) special_adjustment = special_care * 0.01 special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment # concept_scores = concept_scores.round(decimals=3) has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) images[has_nsfw_concepts] = 0.0 # black image return images, has_nsfw_concepts ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion/safety_checker_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
    """Return the pairwise cosine similarity between the rows of two 2-D arrays.

    Args:
        emb_1: Array whose rows are the first set of embeddings.
        emb_2: Array whose rows are the second set of embeddings; must share the
            trailing dimension of `emb_1` for the matrix product to be defined.
        eps: Lower bound applied to each row norm so all-zero rows do not divide by zero.

    Returns:
        Array of shape `(emb_1.shape[0], emb_2.shape[0])` where entry `[i, j]` is the dot
        product of the normalised row `i` of `emb_1` and the normalised row `j` of `emb_2`.
    """
    # `jnp.maximum` floors the norms at `eps` exactly like `jnp.clip(..., a_min=eps)` did, but does not
    # depend on the `a_min` keyword that newer JAX releases deprecate. `keepdims=True` lets the division
    # broadcast row-wise without transposing back and forth.
    norm_1 = jnp.maximum(jnp.linalg.norm(emb_1, axis=1, keepdims=True), eps)
    norm_2 = jnp.maximum(jnp.linalg.norm(emb_2, axis=1, keepdims=True), eps)
    return jnp.matmul(emb_1 / norm_1, (emb_2 / norm_2).T)
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
    """Pre-trained-model wrapper that builds and applies `FlaxStableDiffusionSafetyCheckerModule`."""

    config_class = CLIPConfig
    main_input_name = "clip_input"
    module_class = FlaxStableDiffusionSafetyCheckerModule

    def __init__(
        self,
        config: CLIPConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        """Instantiate the module and hand it to the `FlaxPreTrainedModel` constructor.

        Args:
            config: CLIP configuration forwarded to the module.
            input_shape: Shape of the dummy input used for weight initialisation;
                defaults to `(1, 224, 224, 3)`.
            seed: Seed forwarded to the base class.
            dtype: Computation dtype forwarded to the module and the base class.
            _do_init: Forwarded to the base class unchanged.
            **kwargs: Extra keyword arguments forwarded to the module constructor.
        """
        if input_shape is None:
            input_shape = (1, 224, 224, 3)
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # NOTE: the `rng` annotation is a string on purpose. `jax.random.KeyArray` no longer exists in recent
    # JAX releases, and an eagerly evaluated annotation would make this class fail at definition time.
    def init_weights(self, rng: "jax.Array", input_shape: Tuple, params: Optional[FrozenDict] = None) -> FrozenDict:
        """Initialise the module on a random dummy input and return its `"params"` collection.

        Args:
            rng: PRNG key used both for the dummy input and (after splitting) for the init RNG streams.
            input_shape: Shape of the random dummy `clip_input`.
            params: Accepted for interface compatibility; not used by this implementation.

        Returns:
            The freshly initialised parameters.
        """
        # init input tensor
        clip_input = jax.random.normal(rng, input_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, clip_input)["params"]

        return random_params

    def __call__(
        self,
        clip_input,
        params: Optional[dict] = None,
    ):
        """Apply the safety-checker module to `clip_input`.

        Args:
            clip_input: 4-D input batch; axes are permuted with `(0, 2, 3, 1)` before the module is applied
                (presumably channels-first to channels-last -- confirm against the caller).
            params: Parameters to use; falls back to `self.params` when falsy.

        Returns:
            Whatever the wrapped module returns for the given input.
        """
        clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(clip_input, dtype=jnp.float32),
            rngs={},
        )
class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
    """
    Holds the mean and standard deviation of the CLIP embedder used in stable unCLIP.

    `scale` normalizes image embeddings with these statistics (done before noise is applied) and `unscale`
    reverses that normalization on the noised image embeddings.
    """

    @register_to_config
    def __init__(self, embedding_dim: int = 768):
        super().__init__()

        # Both statistics are stored with shape (1, embedding_dim); mean starts at 0 and std at 1.
        stat_shape = (1, embedding_dim)
        self.mean = nn.Parameter(torch.zeros(*stat_shape))
        self.std = nn.Parameter(torch.ones(*stat_shape))

    def to(
        self,
        torch_device: Optional[Union[str, torch.device]] = None,
        torch_dtype: Optional[torch.dtype] = None,
    ):
        """Re-create both parameters on `torch_device` with `torch_dtype` and return `self`."""
        for stat_name in ("mean", "std"):
            moved = getattr(self, stat_name).to(torch_device)
            converted = moved.to(torch_dtype)
            setattr(self, stat_name, nn.Parameter(converted))
        return self

    def scale(self, embeds):
        """Subtract the stored mean from `embeds` and divide by the stored std."""
        centered = embeds - self.mean
        return centered * 1.0 / self.std

    def unscale(self, embeds):
        """Multiply `embeds` by the stored std and add the stored mean back."""
        return (embeds * self.std) + self.mean
@dataclass
class StableDiffusionSafePipelineOutput(BaseOutput):
    """
    Output class for Safe Stable Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, or `None` if safety checking could not be performed.
        unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work"
            (nsfw) content, or `None` if no safety check was performed or no images were flagged.
        applied_safety_concept (`str`)
            The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]
    unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
    applied_safety_concept: Optional[str]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: SafeStableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """Validate the components, patch outdated configs and register everything on the pipeline.

    Outdated settings are fixed in place (after emitting a deprecation message): a scheduler
    `steps_offset` other than 1 is set to 1, a scheduler `clip_sample` of `True` is set to `False`, and a
    unet `sample_size` below 64 from a pre-0.9.0 config is set to 64.

    Args:
        vae: Registered as `self.vae`; its `block_out_channels` determines `self.vae_scale_factor`.
        text_encoder: Registered as `self.text_encoder`.
        tokenizer: Registered as `self.tokenizer`.
        unet: Registered as `self.unet`.
        scheduler: Registered as `self.scheduler`.
        safety_checker: Registered as `self.safety_checker`; may be `None`, which logs a warning when
            `requires_safety_checker` is true.
        feature_extractor: Registered as `self.feature_extractor`; required whenever a safety checker is given.
        requires_safety_checker: Stored in the pipeline config.

    Raises:
        ValueError: If `safety_checker` is provided without a `feature_extractor`.
    """
    super().__init__()
    safety_concept: Optional[str] = (
        "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
        " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
        " abuse, brutality, cruelty"
    )

    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # The first fragment must be an f-string, otherwise `{self.__class__}` is printed literally.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Only configs written by diffusers < 0.9.0 are treated as carrying the wrong default sample size.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self._safety_text_concept = safety_concept
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device("cuda") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
""" batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # Encode the safety concept text if enable_safety_guidance: safety_concept_input = self.tokenizer( [self._safety_text_concept], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] # duplicate safety embeddings for each generation per prompt, using mps friendly method seq_len = safety_embeddings.shape[1] safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance + sld, we need to do three forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing three forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, safety_embeddings]) else: # For classifier free guidance, we need to do two forward passes. 
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
    """Run the safety checker and replace flagged images with black images.

    Args:
        image: Batch of images as a numpy array (it is copied with `.copy()` and indexed per image).
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the extracted pixel values are cast to before the checker runs.
        enable_safety_guidance: Only selects which hint is appended to the warning message.

    Returns:
        Tuple `(image, has_nsfw_concept, flagged_images)`. Without a safety checker the input is returned
        unchanged and the other two entries are `None`. Otherwise `has_nsfw_concept` is the per-image flag
        sequence returned by the checker and `flagged_images` is an array shaped like `image` that holds the
        original pixels of every flagged image (zeros elsewhere).
    """
    if self.safety_checker is None:
        return image, None, None

    originals = image.copy()
    safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
    )

    # One slot per image in the batch; a fixed leading dimension would break for other batch sizes.
    flagged_images = np.zeros(image.shape)
    if any(has_nsfw_concept):
        logger.warning(
            "Potential NSFW content was detected in one or more images. A black image will be returned"
            " instead."
            f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
        )
        # Use a dedicated loop variable so the per-image flag list that is returned stays intact.
        for idx, is_flagged in enumerate(has_nsfw_concept):
            if is_flagged:
                flagged_images[idx] = originals[idx]
                image[idx] = np.zeros(image[idx].shape)  # black image

    return image, has_nsfw_concept, flagged_images
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only forwarded when
    a parameter of that name exists.

    eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 and should be between [0, 1].

    Returns:
        Dict containing `"eta"` and/or `"generator"` (in that order) for the supported parameters.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_kwargs = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_kwargs if name in step_parameters}
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
def perform_safety_guidance(
    self,
    enable_safety_guidance,
    safety_momentum,
    noise_guidance,
    noise_pred_out,
    i,
    sld_guidance_scale,
    sld_warmup_steps,
    sld_threshold,
    sld_momentum_scale,
    sld_mom_beta,
):
    """Apply one step of Safe Latent Diffusion (SLD) guidance to `noise_guidance`.

    Args:
        enable_safety_guidance: When false, the inputs are returned untouched.
        safety_momentum: Running momentum tensor, or `None` to start from zeros.
        noise_guidance: Classifier-free guidance term to be corrected.
        noise_pred_out: Sequence of noise predictions ordered `(uncond, text, safety_concept)`, which is
            the order in which this pipeline concatenates its prompt embeddings.
        i: Index of the current denoising step.
        sld_guidance_scale: Multiplier applied to the absolute text/safety difference before clamping to 1.
        sld_warmup_steps: The correction is only subtracted once `i >= sld_warmup_steps`.
        sld_threshold: Elements whose text-minus-safety difference is at or above this value get scale 0.
        sld_momentum_scale: Weight of the momentum added to the safety guidance.
        sld_mom_beta: Share of the previous momentum that is kept when updating it.

    Returns:
        Tuple `(noise_guidance, safety_momentum)` after the update.
    """
    if not enable_safety_guidance:
        return noise_guidance, safety_momentum

    if safety_momentum is None:
        safety_momentum = torch.zeros_like(noise_guidance)

    # Index 0 is the unconditional prediction and index 1 the text prediction, matching the
    # [negative, prompt, safety] embedding order used elsewhere in this pipeline.
    noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
    noise_pred_safety_concept = noise_pred_out[2]

    # Equation 6
    scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)

    # Equation 6
    safety_concept_scale = torch.where(
        (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
    )

    # Equation 4
    noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)

    # Equation 7
    noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum

    # Equation 8 -- momentum keeps building up even during warmup.
    safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety

    if i >= sld_warmup_steps:  # Warmup
        # Equation 3
        noise_guidance = noise_guidance - noise_guidance_safety

    return noise_guidance, safety_momentum
torch.FloatTensor], None]] = None, callback_steps: int = 1, sld_guidance_scale: Optional[float] = 1000, sld_warmup_steps: Optional[int] = 10, sld_threshold: Optional[float] = 0.01, sld_momentum_scale: Optional[float] = 0.3, sld_mom_beta: Optional[float] = 0.4, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. 
generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. sld_guidance_scale (`float`, *optional*, defaults to 1000): Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be disabled. sld_warmup_steps (`int`, *optional*, defaults to 10): Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). sld_threshold (`float`, *optional*, defaults to 0.01): Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold` is defined as `lamda` of Eq. 
5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). sld_momentum_scale (`float`, *optional*, defaults to 0.3): Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0 momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). sld_mom_beta (`float`, *optional*, defaults to 0.4): Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance if not enable_safety_guidance: warnings.warn("Safety checker disabled!") # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) safety_momentum = None num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] # default classifier free guidance noise_guidance = noise_pred_text - noise_pred_uncond # Perform SLD guidance if enable_safety_guidance: if safety_momentum is None: safety_momentum = torch.zeros_like(noise_guidance) noise_pred_safety_concept = noise_pred_out[2] # Equation 6 scale = torch.clamp( torch.abs((noise_pred_text - noise_pred_safety_concept)) * 
sld_guidance_scale, max=1.0 ) # Equation 6 safety_concept_scale = torch.where( (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale, ) # Equation 4 noise_guidance_safety = torch.mul( (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale ) # Equation 7 noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum # Equation 8 safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety if i >= sld_warmup_steps: # Warmup # Equation 3 noise_guidance = noise_guidance - noise_guidance_safety noise_pred = noise_pred_uncond + guidance_scale * noise_guidance # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) # 9. Run safety checker image, has_nsfw_concept, flagged_images = self.run_safety_checker( image, device, prompt_embeds.dtype, enable_safety_guidance ) # 10. Convert to PIL if output_type == "pil": image = self.numpy_to_pil(image) if flagged_images is not None: flagged_images = self.numpy_to_pil(flagged_images) if not return_dict: return ( image, has_nsfw_concept, self._safety_text_concept if enable_safety_guidance else None, flagged_images, ) return StableDiffusionSafePipelineOutput( images=image, nsfw_content_detected=has_nsfw_concept, applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, unsafe_images=flagged_images, ) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion_safe/safety_checker.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. 
class SafeStableDiffusionSafetyChecker(PreTrainedModel):
    """
    CLIP-based checker that flags images whose embedding is too close to a set of reference concepts.

    The image is encoded with a CLIP vision model, projected with `visual_projection`, and compared (through the
    module-level `cosine_distance` helper) against 17 concept embeddings and 3 "special care" embeddings. The
    `*_weights` parameters hold the per-concept thresholds that the similarity scores are compared against.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        projection_dim = config.projection_dim
        vision_hidden_size = config.vision_config.hidden_size

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(vision_hidden_size, projection_dim, bias=False)

        # Placeholder values (all ones); none of these parameters is trainable.
        self.concept_embeds = nn.Parameter(torch.ones(17, projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()
    def forward(self, clip_input, images):
        """
        Score every image against the stored concepts.

        Args:
            clip_input: input passed to the CLIP vision model.
            images: the images being checked; returned untouched.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a list with one `bool` per image, `True` when
            at least one concept score is above zero.
        """
        image_embeds = self.visual_projection(self.vision_model(clip_input)[1])  # [1] is the pooled output

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()

        per_image_results = []
        for image_idx in range(image_embeds.shape[0]):
            special_scores = {}
            special_care = []
            concept_scores = {}
            bad_concepts = []

            # increase this value to create a stronger `nfsw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx, concept_cos in enumerate(special_cos_dist[image_idx]):
                threshold = self.special_care_embeds_weights[concept_idx].item()
                score = round(concept_cos - threshold + adjustment, 3)
                special_scores[concept_idx] = score
                if score > 0:
                    # NOTE(review): this is a set literal (not a tuple), kept exactly as in the original code.
                    special_care.append({concept_idx, score})
                    # a special-care hit tightens the threshold for everything scored afterwards
                    adjustment = 0.01

            for concept_idx, concept_cos in enumerate(cos_dist[image_idx]):
                threshold = self.concept_embeds_weights[concept_idx].item()
                score = round(concept_cos - threshold + adjustment, 3)
                concept_scores[concept_idx] = score
                if score > 0:
                    bad_concepts.append(concept_idx)

            per_image_results.append(
                {
                    "special_scores": special_scores,
                    "special_care": special_care,
                    "concept_scores": concept_scores,
                    "bad_concepts": bad_concepts,
                }
            )

        has_nsfw_concepts = [len(image_result["bad_concepts"]) > 0 for image_result in per_image_results]

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
        """
        Tensor-only variant of `forward`: no Python loops, no rounding, and `has_nsfw_concepts` is a tensor.

        Args:
            clip_input: input passed to the CLIP vision model.
            images: the images being checked; returned untouched.

        Returns:
            `(images, has_nsfw_concepts)` where `has_nsfw_concepts` is a boolean tensor reduced over dim 1.
        """
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)

        special_similarity = cosine_distance(image_embeds, self.special_care_embeds)
        concept_similarity = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = special_similarity - self.special_care_embeds_weights + adjustment
        # special_scores = special_scores.round(decimals=3)
        special_care = torch.any(special_scores > 0, dim=1)

        # images with a special-care hit get 0.01 added to every concept score
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(-1, concept_similarity.shape[1])

        concept_scores = (concept_similarity - self.concept_embeds_weights) + special_adjustment
        # concept_scores = concept_scores.round(decimals=3)
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        return images, has_nsfw_concepts
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend `noise_cfg` with a copy of itself rescaled to the per-sample standard deviation of `noise_pred_text`.

    Based on findings of [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: guided noise prediction.
        noise_pred_text: text-conditioned noise prediction.
        guidance_rescale: blend weight; `0.0` keeps `noise_cfg`, `1.0` uses only the rescaled prediction.

    Returns:
        The blended noise prediction.
    """

    def _per_sample_std(tensor):
        # standard deviation over every dimension except the first one
        non_batch_dims = list(range(1, tensor.ndim))
        return tensor.std(dim=non_batch_dims, keepdim=True)

    std_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)

    # rescale the results from guidance (fixes overexposure)
    rescaled = noise_cfg * std_ratio

    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    force_zeros_for_empty_prompt: bool = True,
):
    # Register the sub-models and the config flag, then derive the helpers used by the rest of the pipeline.
    super().__init__()

    modules = dict(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
    )
    self.register_modules(**modules)
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # the scale factor doubles for every VAE block after the first one
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = self.unet.config.sample_size

    self.watermark = StableDiffusionXLWatermarker()
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole models to the CPU with accelerate, keeping only the model currently running on the GPU.

    Each model is moved to `cuda:{gpu_id}` when its `forward` method is called and stays there until the next model
    in the chain runs. Compared to `enable_sequential_cpu_offload` the memory savings are lower, but performance is
    much better due to the iterative execution of the `unet`.

    Raises:
        ImportError: if accelerate is not installed or is older than `0.17.0.dev0`.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")):
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

    from accelerate import cpu_offload_with_hook

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # text encoders first (the first one may be absent), then the unet, then the vae
    offload_order = [self.text_encoder_2]
    if self.text_encoder is not None:
        offload_order.insert(0, self.text_encoder)
    offload_order += [self.unet, self.vae]

    previous_hook = None
    for model in offload_order:
        _, previous_hook = cpu_offload_with_hook(model, device, prev_module_hook=previous_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = previous_hook
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def encode_prompt( self, prompt, device: Optional[torch.device] = None, num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] # Define tokenizers and text encoders tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] text_encoders = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary prompt_embeds_list = [] for tokenizer, text_encoder in zip(tokenizers, text_encoders): if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, tokenizer) text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated 
because CLIP can only handle sequences up to" f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder( text_input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.hidden_states[-2] bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: negative_prompt_embeds = torch.zeros_like(prompt_embeds) negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) elif do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" uncond_tokens: List[str] if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) else: uncond_tokens = negative_prompt negative_prompt_embeds_list = [] for tokenizer, text_encoder in zip(tokenizers, text_encoders): # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) max_length = prompt_embeds.shape[1] uncond_input = tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) negative_prompt_embeds = text_encoder( uncond_input.input_ids.to(device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view( batch_size * num_images_per_prompt, seq_len, -1 ) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes negative_prompt_embeds_list.append(negative_prompt_embeds) negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) bs_embed = pooled_prompt_embeds.shape[0] pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( bs_embed * num_images_per_prompt, -1 ) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
):
    """
    Validate the call arguments of the pipeline and raise `ValueError` on the first problem found.

    The checks run in this order: image size divisible by 8, `callback_steps` a positive integer, exactly one of
    `prompt` / `prompt_embeds` given (and `prompt` a `str` or `list`), at most one of `negative_prompt` /
    `negative_prompt_embeds` given, matching shapes for directly passed embeddings, and pooled embeddings present
    whenever the matching non-pooled embeddings are passed.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    callback_steps_ok = callback_steps is not None and isinstance(callback_steps, int) and callback_steps > 0
    if not callback_steps_ok:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    has_negative_embeds = negative_prompt_embeds is not None

    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not (isinstance(prompt, str) or isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    if negative_prompt is not None and has_negative_embeds:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if has_prompt_embeds and has_negative_embeds and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    if has_prompt_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if has_negative_embeds and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    crops_coords_top_left: Tuple[int, int] = (0, 0),
    target_size: Optional[Tuple[int, int]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled by setting `guidance_scale > 1`.
            A higher guidance scale encourages images that are closely linked to the text `prompt`, usually at
            the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. It is only
            forwarded to schedulers whose `step` method accepts an `eta` argument.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they will be generated from the
            `negative_prompt` input argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Required whenever `prompt_embeds` is passed.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Required whenever `negative_prompt_embeds` is passed.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. `"latent"` returns the denoised latents without decoding
            them through the VAE; any other value is forwarded to the image processor's `postprocess`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        callback (`Callable`, *optional*):
            A function called as `callback(step: int, timestep: int, latents: torch.FloatTensor)` during the
            denoising loop.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that is passed along to the UNet call. Its `"scale"` entry, if present, is also
            used as the LoRA scale when encoding the prompt.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16.
            of that paper. It is only applied when classifier-free guidance is active and the value is greater
            than 0.
        original_size (`Tuple[int]`, *optional*, defaults to `(height, width)`):
            Size conditioning value used when building the additional time ids passed to the UNet.
        crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
            Crop conditioning value used when building the additional time ids passed to the UNet.
        target_size (`Tuple[int]`, *optional*, defaults to `(height, width)`):
            Target size conditioning value used when building the additional time ids passed to the UNet.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        one-element `tuple` whose only item is the generated images (or the latents when
        `output_type == "latent"`).
    """
    # 0. Default height and width to unet
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    add_time_ids = self._get_add_time_ids(
        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
    )

    if do_classifier_free_guidance:
        # Unconditional entries go first so that `chunk(2)` below yields (uncond, text).
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)
    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # make sure the VAE is in float32 mode, as it overflows in float16
    self.vae.to(dtype=torch.float32)

    use_torch_2_0_or_xformers = isinstance(
        self.vae.decoder.mid_block.attentions[0].processor,
        (
            AttnProcessor2_0,
            XFormersAttnProcessor,
            LoRAXFormersAttnProcessor,
            LoRAAttnProcessor2_0,
        ),
    )
    # if xformers or torch_2_0 is used attention block does not need
    # to be in float32 which can save lots of memory
    if use_torch_2_0_or_xformers:
        self.vae.post_quant_conv.to(latents.dtype)
        self.vae.decoder.conv_in.to(latents.dtype)
        self.vae.decoder.mid_block.to(latents.dtype)
    else:
        latents = latents.float()

    if output_type == "latent":
        # Latents are handed back undecoded. Unlike the previous early return, this path now
        # falls through so that the offload hook runs and `return_dict=False` is honoured.
        image = latents
    else:
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = self.watermark.apply_watermark(image)
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (image,)

    return StableDiffusionXLPipelineOutput(images=image)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import StableDiffusionXLImg2ImgPipeline >>> from diffusers.utils import load_image >>> pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( ... "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16 ... 
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a guided noise prediction with a std-matched copy of itself.

    The guided prediction `noise_cfg` is first scaled so that its per-sample standard deviation (taken over every
    dimension except the first) equals that of `noise_pred_text`. The result is then linearly mixed with the
    unscaled `noise_cfg` using `guidance_rescale` as the weight of the scaled copy. Based on Section 3.4 of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: noise prediction after classifier-free guidance.
        noise_pred_text: text-conditioned noise prediction.
        guidance_rescale: mixing weight; `0.0` leaves `noise_cfg` unchanged in value.

    Returns:
        The mixed noise prediction.
    """
    # Reduce over all non-batch dimensions, keeping them so the ratio broadcasts back.
    text_reduce_dims = list(range(1, noise_pred_text.ndim))
    cfg_reduce_dims = list(range(1, noise_cfg.ndim))

    text_std = noise_pred_text.std(dim=text_reduce_dims, keepdim=True)
    cfg_std = noise_cfg.std(dim=cfg_reduce_dims, keepdim=True)

    # std-matched prediction (fixes overexposure)
    std_matched = noise_cfg * (text_std / cfg_std)

    # mixing with the original guided result avoids "plain looking" images
    return guidance_rescale * std_matched + (1 - guidance_rescale) * noise_cfg
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    requires_aesthetics_score: bool = False,
    force_zeros_for_empty_prompt: bool = True,
):
    """Register the pipeline components and derive the image helpers.

    Args:
        vae: autoencoder; its `config.block_out_channels` length determines `vae_scale_factor`.
        text_encoder: first text encoder.
        text_encoder_2: second text encoder.
        tokenizer: tokenizer paired with `text_encoder`.
        tokenizer_2: tokenizer paired with `text_encoder_2`.
        unet: denoising network.
        scheduler: scheduler used together with `unet`.
        requires_aesthetics_score: stored in the pipeline config.
        force_zeros_for_empty_prompt: stored in the pipeline config.
    """
    super().__init__()

    # Keyword order matches the order the modules were registered in originally.
    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "unet": unet,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
    self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # One factor of two per VAE block after the first.
    num_vae_blocks = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (num_vae_blocks - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    self.watermark = StableDiffusionXLWatermarker()
""" self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]: cpu_offload(cpu_offloaded_model, device) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_model_cpu_offload def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) model_sequence = ( [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) model_sequence.extend([self.unet, self.vae]) hook = None for cpu_offloaded_model in model_sequence: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. 
def encode_prompt(
    self,
    prompt,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device; defaults to `self._execution_device`
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided and guidance is enabled, they are either
            zeros (when `negative_prompt` is `None` and `config.force_zeros_for_empty_prompt` is set) or
            generated from `negative_prompt`.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. If not provided, they are generated from `prompt`.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. If not provided and guidance is enabled, they are
            produced together with `negative_prompt_embeds`.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.

    Returns:
        A 4-tuple `(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)`. The two negative entries are `None` when classifier-free guidance is
        disabled and no negative embeddings were supplied by the caller.
    """
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders; the first pair is optional.
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )

    if prompt_embeds is None:
        prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, tokenizer)

            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = text_encoder(
                text_input_ids.to(device),
                output_hidden_states=True,
            )

            # Overwritten on every iteration, so the pooled output of the last text encoder is the one kept.
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]

            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # As above, the pooled output of the last text encoder is the one kept.
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            # Guidance is always enabled inside this branch, so the embeddings are always expanded:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    bs_embed = pooled_prompt_embeds.shape[0]
    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    # Without guidance (and without caller-supplied negatives) there is nothing to expand; the
    # unconditional call used to fail here with an AttributeError on `None`.
    if negative_pooled_prompt_embeds is not None:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def check_inputs(
    self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
    """Validate the arguments of an image-to-image pipeline call.

    The checks run in a fixed order and the first violated one raises:

    1. `strength` must not be below 0 or above 1.
    2. `callback_steps` must be a positive `int`.
    3. Exactly one of `prompt` / `prompt_embeds` must be given, and `prompt` must be a `str` or `list`.
    4. `negative_prompt` and `negative_prompt_embeds` are mutually exclusive.
    5. When both embedding tensors are given, their `.shape` attributes must be equal.

    Returns:
        `None` when every check passes.

    Raises:
        ValueError: describing the first failed check.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # `None` fails the isinstance test, so a missing value is rejected here as well.
    if not (isinstance(callback_steps, int) and callback_steps > 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    prompt_given = prompt is not None
    embeds_given = prompt_embeds is not None

    if prompt_given and embeds_given:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not (prompt_given or embeds_given):
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt_given and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    negative_embeds_given = negative_prompt_embeds is not None

    if negative_prompt is not None and negative_embeds_given:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if embeds_given and negative_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )
)

            elif isinstance(generator, list):
                # one generator per sample: encode and sample each image separately
                init_latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = self.vae.encode(image).latent_dist.sample(generator)

            # restore the VAE to the working dtype after the float32 encode
            self.vae.to(dtype)
            init_latents = init_latents.to(dtype)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents

    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
    ):
        """Build the positive and negative "added time ids" tensors passed to the UNet.

        When `self.config.requires_aesthetics_score` is set, the ids are
        `original_size + crops_coords_top_left + (score,)`, using `aesthetic_score` for the positive ids and
        `negative_aesthetic_score` for the negative ones. Otherwise both are
        `original_size + crops_coords_top_left + target_size`. The tuple arguments are concatenated with `+`.

        Returns:
            `(add_time_ids, add_neg_time_ids)`, each a tensor of shape `(1, len(ids))` with the given `dtype`.

        Raises:
            ValueError: if the embedding width implied by the ids does not match
                `self.unet.add_embedding.linear_1.in_features`.
        """
        if self.config.requires_aesthetics_score:
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
        else:
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)

        # width the UNet's added-embedding layer would receive for these ids
        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        # A difference of exactly one time-embedding slot means the aesthetics-score flag is set the wrong way.
        if (
            expected_add_embed_dim > passed_add_embed_dim
            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
            )
        elif (
            expected_add_embed_dim < passed_add_embed_dim
            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
            )
        elif expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        return add_time_ids, add_neg_time_ids

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        strength: float = 0.3,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
                The image(s) to modify with the pipeline.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation.
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step.
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). guidance_rescale (`float`, *optional*, defaults to 0.7): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): TODO crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): TODO target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): TODO aesthetic_score (`float`, *optional*, defaults to 6.0): TODO negative_aesthetic_score (`float`, *optional*, defaults to 2.5): TDOO Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) ( prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ) = self.encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image image = self.image_processor.preprocess(image) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) # 7. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) height, width = latents.shape[-2:] height = height * self.vae_scale_factor width = width * self.vae_scale_factor original_size = original_size or (height, width) target_size = target_size or (height, width) # 8. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype=prompt_embeds.dtype, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) use_torch_2_0_or_xformers = isinstance( self.vae.decoder.mid_block.attentions[0].processor, ( AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0, ), ) # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory if use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype) else: latents = latents.float() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents return StableDiffusionXLPipelineOutput(images=image) image = self.watermark.apply_watermark(image) image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image,) return StableDiffusionXLPipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/stable_diffusion_xl/watermark.py ================================================ import numpy as np import torch from imwatermark import WatermarkEncoder # Copied from 
https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] class StableDiffusionXLWatermarker: def __init__(self): self.watermark = WATERMARK_BITS self.encoder = WatermarkEncoder() self.encoder.set_watermark("bits", self.watermark) def apply_watermark(self, images: torch.FloatTensor): # can't encode images that are smaller than 256 if images.shape[-1] < 256: return images images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() images = [self.encoder.encode(image, "dwtDct") for image in images] images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) return images ================================================ FILE: 6DoF/diffusers/pipelines/stochastic_karras_ve/__init__.py ================================================ from .pipeline_stochastic_karras_ve import KarrasVePipeline ================================================ FILE: 6DoF/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Optional, Tuple, Union import torch from ...models import UNet2DModel from ...schedulers import KarrasVeScheduler from ...utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput class KarrasVePipeline(DiffusionPipeline): r""" Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 Parameters: unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`KarrasVeScheduler`]): Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image. """ # add type hints for linting unet: UNet2DModel scheduler: KarrasVeScheduler def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__( self, batch_size: int = 1, num_inference_steps: int = 50, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, **kwargs, ) -> Union[Tuple, ImagePipelineOutput]: r""" Args: batch_size (`int`, *optional*, defaults to 1): The number of images to generate. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. 
Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) model = self.unet # sample x_0 ~ N(0, sigma_0^2 * I) sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): # here sigma_t == t_i from the paper sigma = self.scheduler.schedule[t] sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 # 1. Select temporarily increased noise level sigma_hat # 2. Add new noise to move from sample_i to sample_hat sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator) # 3. Predict the noise residual given the noise magnitude `sigma_hat` # The model inputs and output are adjusted by following eq. (213) in [1]. model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample # 4. Evaluate dx/dt at sigma_hat # 5. Take Euler step from sigma to sigma_prev step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat) if sigma_prev != 0: # 6. Apply 2nd order correction # The model inputs and output are adjusted by following eq. (213) in [1]. 
model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample step_output = self.scheduler.step_correct( model_output, sigma_hat, sigma_prev, sample_hat, step_output.prev_sample, step_output["derivative"], ) sample = step_output.prev_sample sample = (sample / 2 + 0.5).clamp(0, 1) image = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/text_to_video_synthesis/__init__.py ================================================ from dataclasses import dataclass from typing import List, Optional, Union import numpy as np import torch from ...utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available @dataclass class TextToVideoSDPipelineOutput(BaseOutput): """ Output class for text to video pipelines. Args: frames (`List[np.ndarray]` or `torch.FloatTensor`) List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as a `torch` tensor. NumPy array present the denoised images of the diffusion pipeline. The length of the list denotes the video length i.e., the number of frames. 
""" frames: Union[List[np.ndarray], torch.FloatTensor] try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_text_to_video_synth import TextToVideoSDPipeline from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401 from .pipeline_text_to_video_zero import TextToVideoZeroPipeline ================================================ FILE: 6DoF/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . 
def tensor2vid(video: torch.Tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) -> List[np.ndarray]:
    """
    Convert a normalized video tensor into a list of `uint8` frames.

    Args:
        video (`torch.Tensor`):
            Tensor that is unpacked as `(batch, channels, frames, height, width)`. It is left unmodified.
        mean (sequence of `float`, *optional*, defaults to `(0.5, 0.5, 0.5)`):
            Per-channel mean that is added back after scaling.
        std (sequence of `float`, *optional*, defaults to `(0.5, 0.5, 0.5)`):
            Per-channel standard deviation the video is multiplied by.

    Returns:
        `List[np.ndarray]`: One `uint8` array of shape `(height, batch * width, channels)` per frame; the batch
        entries are laid out side by side along the width axis.
    """
    # This code is adapted from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
    # reshape the statistics so they broadcast over (batch, channels, frames, height, width)
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    # unnormalize back to [0, 1]; the multiplication is out-of-place so the caller's tensor is never
    # overwritten, while the follow-up in-place ops only touch the freshly allocated result
    video = video.mul(std).add_(mean).clamp_(0, 1)
    # prepare the final outputs
    i, c, f, h, w = video.shape
    images = video.permute(2, 3, 0, 4, 1).reshape(
        f, h, i * w, c
    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
    images = images.unbind(dim=0)  # prepare a list of individual (consecutive) frames
    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
    return images
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet3DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
):
    """
    Register the pipeline components and derive the VAE scale factor.

    Args:
        vae: Variational Auto-Encoder used to encode and decode images to and from latent representations.
        text_encoder: Frozen text encoder.
        tokenizer: Tokenizer paired with the text encoder.
        unet: Conditional U-Net that denoises the encoded video latents.
        scheduler: Scheduler used together with `unet` during denoising.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "unet": unet,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    # The scale factor is 2 raised to (number of VAE block output channel entries - 1).
    scale_exponent = len(self.vae.config.block_out_channels) - 1
    self.vae_scale_factor = 2**scale_exponent
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload the models to CPU with accelerate, one whole model at a time.

    Compared to `enable_sequential_cpu_offload`, each model is moved to the GPU as a whole when its `forward`
    method is called and stays there until the next model runs. Memory savings are lower than with
    `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device the models are executed on.

    Raises:
        ImportError: If accelerate is not available or older than `0.17.0.dev0`.
    """
    accelerate_is_recent_enough = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_is_recent_enough:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    execution_device = torch.device(f"cuda:{gpu_id}")

    pipeline_is_on_cpu = self.device.type == "cpu"
    if not pipeline_is_on_cpu:
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # Chain the hooks: each call receives the hook returned for the previous model.
    previous_hook = None
    for model in (self.text_encoder, self.unet, self.vae):
        _, previous_hook = cpu_offload_with_hook(model, execution_device, prev_module_hook=previous_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = previous_hook
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Build the optional keyword arguments for `self.scheduler.step`.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only forwarded when the
    scheduler's `step` method declares a parameter of that name.

    eta (η) is only used with the DDIMScheduler and is ignored for other schedulers. It corresponds to η in the
    DDIM paper (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Returns:
        `dict`: Contains `"eta"` and/or `"generator"`, depending on what the scheduler accepts.
    """
    accepted_names = set(inspect.signature(self.scheduler.step).parameters.keys())
    optional_kwargs = {"eta": eta, "generator": generator}
    return {name: value for name, value in optional_kwargs.items() if name in accepted_names}
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_frames: int = 16,
    num_inference_steps: int = 50,
    guidance_scale: float = 9.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated video.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated video.
        num_frames (`int`, *optional*, defaults to 16):
            The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
            amounts to 2 seconds of video.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 9.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
            usually at the expense of lower video quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is not greater than `1`).
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a
            latents tensor will be generated by sampling using the supplied random `generator`. Latents should be
            of shape `(batch_size, num_channel, num_frames, height, width)`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. `"pt"` returns the decoded video as a `torch.FloatTensor`,
            `"latent"` returns the denoised latents without decoding them, and any other value returns a list of
            `uint8` `np.ndarray` frames.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
            of a plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

    Examples:

    Returns:
        [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
        [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element holds the generated frames (or the latents when
        `output_type="latent"`).
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    num_images_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = (
        cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    )
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        num_frames,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # fold the frame axis into the batch axis so the scheduler steps every frame as one sample;
            # dedicated names keep the `height`/`width` arguments (pixel sizes) from being overwritten
            bsz, channel, frames, latent_height, latent_width = latents.shape
            latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, latent_height, latent_width)
            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(
                bsz * frames, channel, latent_height, latent_width
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # reshape latents back to (batch, channel, frames, height, width)
            latents = latents.reshape(bsz, frames, channel, latent_height, latent_width).permute(0, 2, 1, 3, 4)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 8. Post-processing: the latent output goes through the same offload / `return_dict` handling as the
    # decoded outputs instead of returning early
    if output_type == "latent":
        video = latents
    else:
        video_tensor = self.decode_latents(latents)
        video = video_tensor if output_type == "pt" else tensor2vid(video_tensor)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    if not return_dict:
        return (video,)

    return TextToVideoSDPipelineOutput(frames=video)
import TextToVideoSDPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler >>> from diffusers.utils import export_to_video >>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.to("cuda") >>> prompt = "spiderman running in the desert" >>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames >>> # safe low-res video >>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4") >>> # let's offload the text-to-image model >>> pipe.to("cpu") >>> # and load the image-to-image model >>> pipe = DiffusionPipeline.from_pretrained( ... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15" ... ) >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode >>> pipe.vae.enable_slicing() >>> # now let's upscale it >>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames] >>> # and denoise it >>> video_frames = pipe(prompt, video=video, strength=0.6).frames >>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4") >>> video_path ``` """ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 # reshape to ncfhw mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1) std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1) # unnormalize 
back to [0,1] video = video.mul_(std).add_(mean) video.clamp_(0, 1) # prepare the final outputs i, c, f, h, w = video.shape images = video.permute(2, 3, 0, 4, 1).reshape( f, h, i * w, c ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c) images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames) images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c return images def preprocess_video(video): supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) if isinstance(video, supported_formats): video = [video] elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): raise ValueError( f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" ) if isinstance(video[0], PIL.Image.Image): video = [np.array(frame) for frame in video] if isinstance(video[0], np.ndarray): video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) if video.dtype == np.uint8: video = np.array(video).astype(np.float32) / 255.0 if video.ndim == 4: video = video[None, ...] video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) elif isinstance(video[0], torch.Tensor): video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) # don't need any preprocess if the video is latents channel = video.shape[1] if channel == 4: return video # move channels before num_frames video = video.permute(0, 2, 1, 3, 4) # normalize video video = 2.0 * video - 1.0 return video class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Same as Stable Diffusion 2. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet3DConditionModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: cpu_offload(cpu_offloaded_model, device) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. 
Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) hook = None for cpu_offloaded_model in [self.text_encoder, self.vae, self.unet]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
""" # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) image = self.vae.decode(latents).sample video = ( image[None, :] .reshape( ( batch_size, num_frames, -1, ) + image.shape[2:] ) .permute(0, 2, 1, 3, 4) ) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None): video = video.to(device=device, dtype=dtype) # change from (b, c, f, h, w) -> (b * f, c, w, h) bsz, channel, frames, width, height = video.shape video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) if video.shape[1] == 4: init_latents = video else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) elif isinstance(generator, list): init_latents = [ self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(video).latent_dist.sample(generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4) return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, video: Union[List[np.ndarray], torch.FloatTensor] = None, strength: float = 0.6, num_inference_steps: int = 50, guidance_scale: float = 15.0, negative_prompt: Optional[Union[str, List[str]]] = None, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. instead. 
video (`List[np.ndarray]` or `torch.FloatTensor`): `video` frames or tensor representing a video batch, that will be used as the starting point for the process. Can also accept video latents as `video`; if latents are passed directly, they will not be encoded again. strength (`float`, *optional*, defaults to 0.6): Conceptually, indicates how much to transform the reference `video`. Must be between 0 and 1. `video` will be used as a starting point, adding more noise to it the larger the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will be maximum and the denoising process will run for the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores `video`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to higher quality videos at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 15.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, usually at the expense of lower video quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. Latents should be of shape `(batch_size, num_channel, num_frames, height, width)`. NOTE: this pipeline currently does not read this argument; the starting latents are always computed from `video`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated video. Choose between `"np"` (a list of `np.ndarray` frames), `"pt"` (a `torch.FloatTensor`) and `"latent"` (the denoised latents, returned without decoding). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step.
cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated frames. """ # 0. Default height and width to unet num_images_per_prompt = 1 # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) # 4. Preprocess video video = preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5. 
Prepare latent variables latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, return_dict=False, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # reshape latents bsz, channel, frames, width, height = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # reshape latents back latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if output_type == "latent": return TextToVideoSDPipelineOutput(frames=latents) if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 
self.unet.to("cpu") video_tensor = self.decode_latents(latents) if output_type == "pt": video = video_tensor else: video = tensor2vid(video_tensor) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (video,) return TextToVideoSDPipelineOutput(frames=video) ================================================ FILE: 6DoF/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py ================================================ import copy from dataclasses import dataclass from typing import Callable, List, Optional, Union import numpy as np import PIL import torch import torch.nn.functional as F from torch.nn.functional import grid_sample from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import BaseOutput def rearrange_0(tensor, f): F, C, H, W = tensor.size() tensor = torch.permute(torch.reshape(tensor, (F // f, f, C, H, W)), (0, 2, 1, 3, 4)) return tensor def rearrange_1(tensor): B, C, F, H, W = tensor.size() return torch.reshape(torch.permute(tensor, (0, 2, 1, 3, 4)), (B * F, C, H, W)) def rearrange_3(tensor, f): F, D, C = tensor.size() return torch.reshape(tensor, (F // f, f, D, C)) def rearrange_4(tensor): B, F, D, C = tensor.size() return torch.reshape(tensor, (B * F, D, C)) class CrossFrameAttnProcessor: """ Cross frame attention processor. Each frame attends the first frame. Args: batch_size: The number that represents actual batch size, other than the frames. For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to 2, due to classifier-free guidance. 
""" def __init__(self, batch_size=2): self.batch_size = batch_size def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) is_cross_attention = encoder_hidden_states is not None if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Cross Frame Attention if not is_cross_attention: video_length = key.size()[0] // self.batch_size first_frame_index = [0] * video_length # rearrange keys to have batch and frames in the 1st and 2nd dims respectively key = rearrange_3(key, video_length) key = key[:, first_frame_index] # rearrange values to have batch and frames in the 1st and 2nd dims respectively value = rearrange_3(value, video_length) value = value[:, first_frame_index] # rearrange back to original shape key = rearrange_4(key) value = rearrange_4(value) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) return hidden_states class CrossFrameAttnProcessor2_0: """ Cross frame attention processor with scaled_dot_product attention of Pytorch 2.0. Args: batch_size: The number that represents actual batch size, other than the frames. For example, calling unet with a single prompt and num_images_per_prompt=1, batch_size should be equal to 2, due to classifier-free guidance. 
""" def __init__(self, batch_size=2): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.batch_size = batch_size def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) inner_dim = hidden_states.shape[-1] if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) query = attn.to_q(hidden_states) is_cross_attention = encoder_hidden_states is not None if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) # Cross Frame Attention if not is_cross_attention: video_length = key.size()[0] // self.batch_size first_frame_index = [0] * video_length # rearrange keys to have batch and frames in the 1st and 2nd dims respectively key = rearrange_3(key, video_length) key = key[:, first_frame_index] # rearrange values to have batch and frames in the 1st and 2nd dims respectively value = rearrange_3(value, video_length) value = value[:, first_frame_index] # rearrange back to original shape key = rearrange_4(key) value = rearrange_4(value) head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for 
def coords_grid(batch, ht, wd, device):
    """Build a `(batch, 2, ht, wd)` float grid of pixel coordinates.

    Channel 0 holds the x (column) index and channel 1 the y (row) index of every position.
    """
    # Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
    row_ids = torch.arange(ht, device=device)
    col_ids = torch.arange(wd, device=device)
    y_grid, x_grid = torch.meshgrid(row_ids, col_ids)
    xy = torch.stack((x_grid, y_grid), dim=0).float()
    return xy.unsqueeze(0).repeat(batch, 1, 1, 1)


def warp_single_latent(latent, reference_flow):
    """
    Warp latent of a single frame with given flow

    Args:
        latent: latent code of a single frame
        reference_flow: flow which to warp the latent with

    Returns:
        warped: warped latent
    """
    flow_h, flow_w = reference_flow.size()[2:]
    latent_h, latent_w = latent.size()[2:]

    # Absolute sampling positions at flow resolution: base pixel grid displaced by the flow.
    base_grid = coords_grid(1, flow_h, flow_w, device=latent.device).to(latent.dtype)
    sample_grid = base_grid + reference_flow

    # Normalise x by the width and y by the height, then map [0, 1] to grid_sample's [-1, 1] range.
    sample_grid[:, 0] /= flow_w
    sample_grid[:, 1] /= flow_h
    sample_grid = sample_grid * 2.0 - 1.0

    # Resample the grid to the latent resolution and move the coordinate pair to the last axis.
    sample_grid = F.interpolate(sample_grid, size=(latent_h, latent_w), mode="bilinear")
    sample_grid = sample_grid.permute(0, 2, 3, 1)

    return grid_sample(latent, sample_grid, mode="nearest", padding_mode="reflection")


def create_motion_field(motion_field_strength_x, motion_field_strength_y, frame_ids, device, dtype):
    """
    Create translation motion field

    Args:
        motion_field_strength_x: motion strength along x-axis
        motion_field_strength_y: motion strength along y-axis
        frame_ids: indexes of the frames the latents of which are being processed.
            This is needed when we perform chunk-by-chunk inference
        device: device
        dtype: dtype

    Returns:
        A `(len(frame_ids), 2, 512, 512)` tensor; for each frame, channel 0 is filled with
        `motion_field_strength_x * frame_id` and channel 1 with `motion_field_strength_y * frame_id`.
    """
    flow = torch.zeros((len(frame_ids), 2, 512, 512), device=device, dtype=dtype)
    for slot, frame_id in enumerate(frame_ids):
        # A constant value per channel, i.e. a uniform translation scaled by the frame id.
        flow[slot, 0] = motion_field_strength_x * frame_id
        flow[slot, 1] = motion_field_strength_y * frame_id
    return flow


def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_strength_y, frame_ids, latents):
    """
    Creates translation motion and warps the latents accordingly

    Args:
        motion_field_strength_x: motion strength along x-axis
        motion_field_strength_y: motion strength along y-axis
        frame_ids: indexes of the frames the latents of which are being processed.
            This is needed when we perform chunk-by-chunk inference
        latents: latent codes of frames

    Returns:
        warped_latents: warped latents
    """
    flows = create_motion_field(
        motion_field_strength_x=motion_field_strength_x,
        motion_field_strength_y=motion_field_strength_y,
        frame_ids=frame_ids,
        device=latents.device,
        dtype=latents.dtype,
    )
    # Work on a detached copy so the input latents are left untouched.
    warped = latents.clone().detach()
    for idx in range(warped.shape[0]):
        # Each frame is warped on its own, with a singleton batch axis added to latent and flow.
        warped[idx] = warp_single_latent(latents[idx].unsqueeze(0), flows[idx].unsqueeze(0))
    return warped
Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker ) processor = ( CrossFrameAttnProcessor2_0(batch_size=2) if hasattr(F, "scaled_dot_product_attention") else CrossFrameAttnProcessor(batch_size=2) ) self.unet.set_attn_processor(processor) def forward_loop(self, x_t0, t0, t1, generator): """ Perform ddpm forward process from time t0 to t1. This is the same as adding noise with corresponding variance. 
Args: x_t0: latent code at time t0 t0: t0 t1: t1 generator: torch.Generator object Returns: x_t1: forward process applied to x_t0 from time t0 to t1. """ eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device) alpha_vec = torch.prod(self.scheduler.alphas[t0:t1]) x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps return x_t1 def backward_loop( self, latents, timesteps, prompt_embeds, guidance_scale, callback, callback_steps, num_warmup_steps, extra_step_kwargs, cross_attention_kwargs=None, ): """ Perform backward process given list of time steps Args: latents: Latents at time timesteps[0]. timesteps: time steps, along which to perform backward process. prompt_embeds: Pre-generated text embeddings guidance_scale: Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. extra_step_kwargs: extra_step_kwargs. cross_attention_kwargs: cross_attention_kwargs. num_warmup_steps: number of warmup steps. 
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int] = 8,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        motion_field_strength_x: float = 12,
        motion_field_strength_y: float = 12,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        t0: int = 44,
        t1: int = 47,
        frame_ids: Optional[List[int]] = None,
    ):
        """
        Function invoked when calling the pipeline for generation.

        The first frame is denoised down to time T_0, copied to the remaining frames, warped with a translation
        motion field, re-noised up to time T_1, and finally all frames are denoised together from T_1 to 0.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation. A single string is wrapped into a list.
            video_length (`int`, *optional*, defaults to 8):
                The number of generated video frames. Must be greater than 0.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. A single string is wrapped into a list.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`).
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt. Only `1` is accepted (checked with an assertion).
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            motion_field_strength_x (`float`, *optional*, defaults to 12):
                Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
                Sect. 3.3.1.
            motion_field_strength_y (`float`, *optional*, defaults to 12):
                Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
                Sect. 3.3.1.
            output_type (`str`, *optional*, defaults to `"tensor"`):
                The output format of the generated image. With `"latent"` the final latents are returned without
                decoding and without running the safety checker; any other value returns the decoded images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`] instead of
                a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            t0 (`int`, *optional*, defaults to 44):
                Timestep t0. Should be in the range [0, num_inference_steps - 1]. It is counted from the end of the
                scheduler timesteps (`timesteps[-t0 - 1]`). See the [paper](https://arxiv.org/abs/2303.13439),
                Sect. 3.3.1.
            t1 (`int`, *optional*, defaults to 47):
                Timestep t1. Should be in the range [t0 + 1, num_inference_steps - 1]. It is counted from the end of
                the scheduler timesteps (`timesteps[-t1 - 1]`). See the [paper](https://arxiv.org/abs/2303.13439),
                Sect. 3.3.1.
            frame_ids (`List[int]`, *optional*):
                Indexes of the frames that are being generated. This is used when generating longer videos
                chunk-by-chunk. Defaults to `list(range(video_length))`; its length must equal `video_length`.

        Returns:
            [`~pipelines.text_to_video_synthesis.TextToVideoPipelineOutput`] if `return_dict` is `True`, otherwise a
            tuple `(images, nsfw_content_detected)`. `images` holds the decoded frames when output_type != 'latent',
            otherwise the latent codes of the generated frames; `nsfw_content_detected` is the result of the
            `safety_checker`, or `None` when output_type == 'latent'.
        """
        # NOTE(review): these input checks are assertions, so they disappear when Python runs with -O.
        assert video_length > 0
        if frame_ids is None:
            frame_ids = list(range(video_length))
        assert len(frame_ids) == video_length

        assert num_videos_per_prompt == 1

        # Normalise single strings to one-element lists.
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        # A `str` prompt was already wrapped into a list above, so only the `len(prompt)` branch applies to it here.
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Perform the first backward process up to time T_1
        # (all timesteps except the last `t1 + 1` entries)
        x_1_t1 = self.backward_loop(
            timesteps=timesteps[: -t1 - 1],
            prompt_embeds=prompt_embeds,
            latents=latents,
            guidance_scale=guidance_scale,
            callback=callback,
            callback_steps=callback_steps,
            extra_step_kwargs=extra_step_kwargs,
            num_warmup_steps=num_warmup_steps,
        )
        # Snapshot the scheduler as it is at T_1; it is put back before the final loop below.
        # NOTE(review): presumably needed for schedulers that keep internal state between steps - confirm.
        scheduler_copy = copy.deepcopy(self.scheduler)

        # Perform the second backward process up to time T_0
        x_1_t0 = self.backward_loop(
            timesteps=timesteps[-t1 - 1 : -t0 - 1],
            prompt_embeds=prompt_embeds,
            latents=x_1_t1,
            guidance_scale=guidance_scale,
            callback=callback,
            callback_steps=callback_steps,
            extra_step_kwargs=extra_step_kwargs,
            num_warmup_steps=0,
        )

        # Propagate first frame latents at time T_0 to remaining frames
        x_2k_t0 = x_1_t0.repeat(video_length - 1, 1, 1, 1)

        # Add motion in latents at time T_0
        # (the first frame is left unwarped, hence `frame_ids[1:]`)
        x_2k_t0 = create_motion_field_and_warp_latents(
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            latents=x_2k_t0,
            frame_ids=frame_ids[1:],
        )

        # Perform forward process up to time T_1
        # (`forward_loop` receives the scheduler timestep values stored at those positions)
        x_2k_t1 = self.forward_loop(
            x_t0=x_2k_t0,
            t0=timesteps[-t0 - 1].item(),
            t1=timesteps[-t1 - 1].item(),
            generator=generator,
        )

        # Perform backward process from time T_1 to 0
        # First-frame latents at T_1 followed by the re-noised latents of the remaining frames.
        x_1k_t1 = torch.cat([x_1_t1, x_2k_t1])
        # Repeat every prompt embedding `video_length` times so each frame has its own copy.
        b, l, d = prompt_embeds.size()
        prompt_embeds = prompt_embeds[:, None].repeat(1, video_length, 1, 1).reshape(b * video_length, l, d)

        # Restore the scheduler snapshot taken at T_1 before denoising all frames together.
        self.scheduler = scheduler_copy
        x_1k_0 = self.backward_loop(
            timesteps=timesteps[-t1 - 1 :],
            prompt_embeds=prompt_embeds,
            latents=x_1k_t1,
            guidance_scale=guidance_scale,
            callback=callback,
            callback_steps=callback_steps,
            extra_step_kwargs=extra_step_kwargs,
            num_warmup_steps=0,
        )
        latents = x_1k_0

        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
        torch.cuda.empty_cache()

        if output_type == "latent":
            # Return the raw latents; nothing is decoded, so there is nothing for the safety checker to inspect.
            image = latents
            has_nsfw_concept = None
        else:
            image = self.decode_latents(latents)
            # Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import List, Optional, Tuple, Union import torch from torch.nn import functional as F from transformers import CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...pipelines import DiffusionPipeline from ...pipelines.pipeline_utils import ImagePipelineOutput from ...schedulers import UnCLIPScheduler from ...utils import is_accelerate_available, logging, randn_tensor from .text_proj import UnCLIPTextProjModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class UnCLIPPipeline(DiffusionPipeline): """ Pipeline for text-to-image generation using unCLIP This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: text_encoder ([`CLIPTextModelWithProjection`]): Frozen text-encoder. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). prior ([`PriorTransformer`]): The canonincal unCLIP prior to approximate the image embedding from the text embedding. text_proj ([`UnCLIPTextProjModel`]): Utility class to prepare and combine the embeddings before they are passed to the decoder. decoder ([`UNet2DConditionModel`]): The decoder to invert the image embedding into an image. super_res_first ([`UNet2DModel`]): Super resolution unet. Used in all but the last step of the super resolution diffusion process. super_res_last ([`UNet2DModel`]): Super resolution unet. Used in the last step of the super resolution diffusion process. prior_scheduler ([`UnCLIPScheduler`]): Scheduler used in the prior denoising process. 
Just a modified DDPMScheduler. decoder_scheduler ([`UnCLIPScheduler`]): Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. super_res_scheduler ([`UnCLIPScheduler`]): Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. """ prior: PriorTransformer decoder: UNet2DConditionModel text_proj: UnCLIPTextProjModel text_encoder: CLIPTextModelWithProjection tokenizer: CLIPTokenizer super_res_first: UNet2DModel super_res_last: UNet2DModel prior_scheduler: UnCLIPScheduler decoder_scheduler: UnCLIPScheduler super_res_scheduler: UnCLIPScheduler def __init__( self, prior: PriorTransformer, decoder: UNet2DConditionModel, text_encoder: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, text_proj: UnCLIPTextProjModel, super_res_first: UNet2DModel, super_res_last: UNet2DModel, prior_scheduler: UnCLIPScheduler, decoder_scheduler: UnCLIPScheduler, super_res_scheduler: UnCLIPScheduler, ): super().__init__() self.register_modules( prior=prior, decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, text_proj=text_proj, super_res_first=super_res_first, super_res_last=super_res_last, prior_scheduler=prior_scheduler, decoder_scheduler=decoder_scheduler, super_res_scheduler=super_res_scheduler, ) def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) latents = latents * scheduler.init_noise_sigma return latents def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None, ): if text_model_output is None: batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings 
text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids text_mask = text_inputs.attention_mask.bool().to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_encoder_output = self.text_encoder(text_input_ids.to(device)) prompt_embeds = text_encoder_output.text_embeds text_encoder_hidden_states = text_encoder_output.last_hidden_state else: batch_size = text_model_output[0].shape[0] prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1] text_mask = text_attention_mask prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: uncond_tokens = [""] * batch_size uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) uncond_text_mask = uncond_input.attention_mask.bool().to(device) negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state # duplicate unconditional embeddings for each 
generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) seq_len = uncond_text_encoder_hidden_states.shape[1] uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( batch_size * num_images_per_prompt, seq_len, -1 ) uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) # done duplicates # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) text_mask = torch.cat([uncond_text_mask, text_mask]) return prompt_embeds, text_encoder_hidden_states, text_mask def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list models = [ self.decoder, self.text_proj, self.text_encoder, self.super_res_first, self.super_res_last, ] for cpu_offloaded_model in models: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): return self.device for module in self.decoder.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, prior_num_inference_steps: int = 25, decoder_num_inference_steps: int = 25, super_res_num_inference_steps: int = 7, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prior_latents: Optional[torch.FloatTensor] = None, decoder_latents: Optional[torch.FloatTensor] = None, super_res_latents: Optional[torch.FloatTensor] = None, text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, text_attention_mask: Optional[torch.Tensor] = None, prior_guidance_scale: float = 4.0, decoder_guidance_scale: float = 8.0, output_type: Optional[str] = "pil", return_dict: bool = True, ): """ Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. 
This can only be left undefined if `text_model_output` and `text_attention_mask` is passed. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. prior_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the prior. More denoising steps usually lead to a higher quality image at the expense of slower inference. decoder_num_inference_steps (`int`, *optional*, defaults to 25): The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality image at the expense of slower inference. super_res_num_inference_steps (`int`, *optional*, defaults to 7): The number of denoising steps for super resolution. More denoising steps usually lead to a higher quality image at the expense of slower inference. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): Pre-generated noisy latents to be used as inputs for the prior. decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): Pre-generated noisy latents to be used as inputs for the decoder. super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): Pre-generated noisy latents to be used as inputs for the decoder. prior_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. decoder_guidance_scale (`float`, *optional*, defaults to 4.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. text_model_output (`CLIPTextModelOutput`, *optional*): Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs can be passed for tasks like text embedding interpolations. Make sure to also pass `text_attention_mask` in this case. `prompt` can the be left to `None`. text_attention_mask (`torch.Tensor`, *optional*): Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention masks are necessary when passing `text_model_output`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
""" if prompt is not None: if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") else: batch_size = text_model_output[0].shape[0] device = self._execution_device batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask ) # prior self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) prior_timesteps_tensor = self.prior_scheduler.timesteps embedding_dim = self.prior.config.embedding_dim prior_latents = self.prepare_latents( (batch_size, embedding_dim), prompt_embeds.dtype, device, generator, prior_latents, self.prior_scheduler, ) for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents predicted_image_embedding = self.prior( latent_model_input, timestep=t, proj_embedding=prompt_embeds, encoder_hidden_states=text_encoder_hidden_states, attention_mask=text_mask, ).predicted_image_embedding if do_classifier_free_guidance: predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( predicted_image_embedding_text - predicted_image_embedding_uncond ) if i + 1 == prior_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = prior_timesteps_tensor[i + 1] prior_latents = self.prior_scheduler.step( predicted_image_embedding, timestep=t, sample=prior_latents, generator=generator, prev_timestep=prev_timestep, ).prev_sample prior_latents = 
self.prior.post_process_latents(prior_latents) image_embeddings = prior_latents # done prior # decoder text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( image_embeddings=image_embeddings, prompt_embeds=prompt_embeds, text_encoder_hidden_states=text_encoder_hidden_states, do_classifier_free_guidance=do_classifier_free_guidance, ) if device.type == "mps": # HACK: MPS: There is a panic when padding bool tensors, # so cast to int tensor for the pad and back to bool afterwards text_mask = text_mask.type(torch.int) decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) decoder_text_mask = decoder_text_mask.type(torch.bool) else: decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True) self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) decoder_timesteps_tensor = self.decoder_scheduler.timesteps num_channels_latents = self.decoder.config.in_channels height = self.decoder.config.sample_size width = self.decoder.config.sample_size decoder_latents = self.prepare_latents( (batch_size, num_channels_latents, height, width), text_encoder_hidden_states.dtype, device, generator, decoder_latents, self.decoder_scheduler, ) for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents noise_pred = self.decoder( sample=latent_model_input, timestep=t, encoder_hidden_states=text_encoder_hidden_states, class_labels=additive_clip_time_embeddings, attention_mask=decoder_text_mask, ).sample if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) noise_pred = noise_pred_uncond + 
decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) if i + 1 == decoder_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = decoder_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 decoder_latents = self.decoder_scheduler.step( noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample decoder_latents = decoder_latents.clamp(-1, 1) image_small = decoder_latents # done decoder # super res self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) super_res_timesteps_tensor = self.super_res_scheduler.timesteps channels = self.super_res_first.config.in_channels // 2 height = self.super_res_first.config.sample_size width = self.super_res_first.config.sample_size super_res_latents = self.prepare_latents( (batch_size, channels, height, width), image_small.dtype, device, generator, super_res_latents, self.super_res_scheduler, ) if device.type == "mps": # MPS does not support many interpolations image_upscaled = F.interpolate(image_small, size=[height, width]) else: interpolate_antialias = {} if "antialias" in inspect.signature(F.interpolate).parameters: interpolate_antialias["antialias"] = True image_upscaled = F.interpolate( image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias ) for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): # no classifier free guidance if i == super_res_timesteps_tensor.shape[0] - 1: unet = self.super_res_last else: unet = self.super_res_first latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) noise_pred = unet( sample=latent_model_input, timestep=t, ).sample if i + 1 == super_res_timesteps_tensor.shape[0]: prev_timestep = None else: prev_timestep = super_res_timesteps_tensor[i + 1] # compute the previous noisy sample x_t -> x_t-1 super_res_latents = self.super_res_scheduler.step( 
class UnCLIPImageVariationPipeline(DiffusionPipeline):
    """
    Pipeline to generate variations from an input image using unCLIP

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `image_encoder`.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_proj ([`UnCLIPTextProjModel`]):
            Utility class to prepare and combine the embeddings before they are passed to the decoder.
        decoder ([`UNet2DConditionModel`]):
            The decoder to invert the image embedding into an image.
        super_res_first ([`UNet2DModel`]):
            Super resolution unet. Used in all but the last step of the super resolution diffusion process.
        super_res_last ([`UNet2DModel`]):
            Super resolution unet. Used in the last step of the super resolution diffusion process.
        decoder_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
        super_res_scheduler ([`UnCLIPScheduler`]):
            Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.

    """

    decoder: UNet2DConditionModel
    text_proj: UnCLIPTextProjModel
    text_encoder: CLIPTextModelWithProjection
    tokenizer: CLIPTokenizer
    feature_extractor: CLIPImageProcessor
    image_encoder: CLIPVisionModelWithProjection
    super_res_first: UNet2DModel
    super_res_last: UNet2DModel

    decoder_scheduler: UnCLIPScheduler
    super_res_scheduler: UnCLIPScheduler

    def __init__(
        self,
        decoder: UNet2DConditionModel,
        text_encoder: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        text_proj: UnCLIPTextProjModel,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection,
        super_res_first: UNet2DModel,
        super_res_last: UNet2DModel,
        decoder_scheduler: UnCLIPScheduler,
        super_res_scheduler: UnCLIPScheduler,
    ):
        super().__init__()

        self.register_modules(
            decoder=decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_proj=text_proj,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        latents = latents * scheduler.init_noise_sigma
        return latents

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
        """
        Encode `prompt` with the CLIP text encoder and repeat the results per requested image.

        Returns a tuple `(prompt_embeds, text_encoder_hidden_states, text_mask)`. When
        `do_classifier_free_guidance` is set, the encodings of empty prompts are prepended along the batch
        dimension, so every returned tensor has twice the batch size.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        text_mask = text_inputs.attention_mask.bool().to(device)
        text_encoder_output = self.text_encoder(text_input_ids.to(device))

        prompt_embeds = text_encoder_output.text_embeds
        text_encoder_hidden_states = text_encoder_output.last_hidden_state

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

            negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
            uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method

            # `text_embeds` is 2-D here, so dim 1 is the embedding width (not a sequence length).
            embed_dim = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, embed_dim)

            seq_len = uncond_text_encoder_hidden_states.shape[1]
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # done duplicates

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])

            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_encoder_hidden_states, text_mask

    def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
        """
        Return CLIP image embeddings repeated `num_images_per_prompt` times along the batch dimension.

        If `image_embeddings` is given it is used as is and `image` is ignored. Otherwise `image` is run
        through the feature extractor (unless it already is a tensor) and the image encoder.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if image_embeddings is None:
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

            image = image.to(device=device, dtype=dtype)
            image_embeddings = self.image_encoder(image).image_embeds

        image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        return image_embeddings

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
        only when their specific submodule has its `forward` method called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        # NOTE(review): `image_encoder` is not part of this list, so it is not offloaded - confirm this is intended.
        models = [
            self.decoder,
            self.text_proj,
            self.text_encoder,
            self.super_res_first,
            self.super_res_last,
        ]
        for cpu_offloaded_model in models:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
            return self.device
        for module in self.decoder.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]] = None,
        num_images_per_prompt: int = 1,
        decoder_num_inference_steps: int = 25,
        super_res_num_inference_steps: int = 7,
        generator: Optional[torch.Generator] = None,
        decoder_latents: Optional[torch.FloatTensor] = None,
        super_res_latents: Optional[torch.FloatTensor] = None,
        image_embeddings: Optional[torch.Tensor] = None,
        decoder_guidance_scale: float = 8.0,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                configuration of
                [this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`. Can be left to `None` only when `image_embeddings` are passed.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            decoder_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
                image at the expense of slower inference.
            super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
                Pre-generated noisy latents to be used as inputs for the decoder.
            super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
                Pre-generated noisy latents to be used as inputs for the super resolution unets.
            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            image_embeddings (`torch.Tensor`, *optional*):
                Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
                can be passed for tasks like image interpolations. `image` can then be left to `None`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] if `return_dict` is `True`, otherwise a one-element tuple holding the
            generated images.

        Raises:
            ValueError: If neither `image` nor `image_embeddings` is passed, or if pre-generated latents do not have
                the expected shape.
        """
        if image is not None:
            if isinstance(image, PIL.Image.Image):
                batch_size = 1
            elif isinstance(image, list):
                batch_size = len(image)
            else:
                batch_size = image.shape[0]
        elif image_embeddings is not None:
            batch_size = image_embeddings.shape[0]
        else:
            # Fail early with a clear message instead of an AttributeError on `None.shape`.
            raise ValueError("Either `image` or `image_embeddings` has to be passed.")

        # The variation pipeline is conditioned on the image only, so the text branch receives empty prompts.
        prompt = [""] * batch_size

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = decoder_guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance
        )

        image_embeddings = self._encode_image(image, device, num_images_per_prompt, image_embeddings)

        # decoder
        text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
            image_embeddings=image_embeddings,
            prompt_embeds=prompt_embeds,
            text_encoder_hidden_states=text_encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )

        if device.type == "mps":
            # HACK: MPS: There is a panic when padding bool tensors,
            # so cast to int tensor for the pad and back to bool afterwards
            text_mask = text_mask.type(torch.int)
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
            decoder_text_mask = decoder_text_mask.type(torch.bool)
        else:
            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
        decoder_timesteps_tensor = self.decoder_scheduler.timesteps

        num_channels_latents = self.decoder.config.in_channels
        height = self.decoder.config.sample_size
        width = self.decoder.config.sample_size

        # Always go through `prepare_latents` so that user-supplied latents are shape-checked, moved to the
        # execution device and scaled by the scheduler's initial sigma, exactly like freshly sampled ones.
        decoder_latents = self.prepare_latents(
            (batch_size, num_channels_latents, height, width),
            text_encoder_hidden_states.dtype,
            device,
            generator,
            decoder_latents,
            self.decoder_scheduler,
        )

        for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents

            noise_pred = self.decoder(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                class_labels=additive_clip_time_embeddings,
                attention_mask=decoder_text_mask,
            ).sample

            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                # The decoder output is split channel-wise: the first part is guided, the remainder
                # (taken from the conditional half) is passed on to the scheduler unchanged.
                noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
                noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
                noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)

            if i + 1 == decoder_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = decoder_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            decoder_latents = self.decoder_scheduler.step(
                noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        decoder_latents = decoder_latents.clamp(-1, 1)

        image_small = decoder_latents

        # done decoder

        # super res

        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
        super_res_timesteps_tensor = self.super_res_scheduler.timesteps

        # The super resolution unets take the noisy latents concatenated with the upscaled decoder output
        # along the channel dimension, hence only half of `in_channels` are latent channels.
        channels = self.super_res_first.config.in_channels // 2
        height = self.super_res_first.config.sample_size
        width = self.super_res_first.config.sample_size

        # Same reasoning as for the decoder latents: validate / move / scale user-supplied latents too.
        super_res_latents = self.prepare_latents(
            (batch_size, channels, height, width),
            image_small.dtype,
            device,
            generator,
            super_res_latents,
            self.super_res_scheduler,
        )

        if device.type == "mps":
            # MPS does not support many interpolations
            image_upscaled = F.interpolate(image_small, size=[height, width])
        else:
            interpolate_antialias = {}
            # `antialias` is only passed when the installed `F.interpolate` accepts it.
            if "antialias" in inspect.signature(F.interpolate).parameters:
                interpolate_antialias["antialias"] = True

            image_upscaled = F.interpolate(
                image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
            )

        for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
            # no classifier free guidance

            if i == super_res_timesteps_tensor.shape[0] - 1:
                unet = self.super_res_last
            else:
                unet = self.super_res_first

            latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)

            noise_pred = unet(
                sample=latent_model_input,
                timestep=t,
            ).sample

            if i + 1 == super_res_timesteps_tensor.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = super_res_timesteps_tensor[i + 1]

            # compute the previous noisy sample x_t -> x_t-1
            super_res_latents = self.super_res_scheduler.step(
                noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        image = super_res_latents

        # done super res

        # post processing

        # map from [-1, 1] to [0, 1] and convert to channels-last numpy
        image = image * 0.5 + 0.5
        image = image.clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
def forward(self, *, image_embeddings, prompt_embeds, text_encoder_hidden_states, do_classifier_free_guidance):
    """
    Combine CLIP image and text embeddings into the conditioning inputs of the decoder.

    Args:
        image_embeddings: CLIP image embeddings. When `do_classifier_free_guidance` is set, the learned
            classifier free guidance embedding is prepended once per row, doubling the batch size.
        prompt_embeds: CLIP text embeddings. Must have the same batch size as the (possibly doubled)
            image embeddings.
        text_encoder_hidden_states: Hidden states of the text encoder, projected to the cross attention
            dimension and normalized before the extra context tokens are prepended.
        do_classifier_free_guidance: Whether the inputs are laid out for classifier free guidance.

    Returns:
        A tuple `(text_encoder_hidden_states, additive_clip_time_embeddings)`: the hidden states with
        `clip_extra_context_tokens` tokens prepended along dim 1, and the sum of the time-projected image
        and text embeddings.

    Raises:
        ValueError: If the image and text embedding batch sizes differ.
    """
    if do_classifier_free_guidance:
        # Add the classifier free guidance embeddings to the image embeddings
        image_embeddings_batch_size = image_embeddings.shape[0]
        classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0)
        classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand(
            image_embeddings_batch_size, -1
        )
        image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)

    # The image embeddings batch size and the text embeddings batch size must be equal.
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if image_embeddings.shape[0] != prompt_embeds.shape[0]:
        raise ValueError(
            f"`image_embeddings` batch size ({image_embeddings.shape[0]}) has to match `prompt_embeds` batch size"
            f" ({prompt_embeds.shape[0]})."
        )

    batch_size = prompt_embeds.shape[0]

    # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
    # adding CLIP embeddings to the existing timestep embedding, ...
    time_projected_prompt_embeds = self.embedding_proj(prompt_embeds)
    time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
    additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_prompt_embeds

    # ... and by projecting CLIP embeddings into four
    # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
    clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
    # (batch, tokens * dim) -> (batch, dim, tokens) -> (batch, tokens, dim)
    clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
    clip_extra_context_tokens = clip_extra_context_tokens.permute(0, 2, 1)

    text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
    text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
    text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=1)

    return text_encoder_hidden_states, additive_clip_time_embeddings
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    The model wraps a `GPT2LMHeadModel` and optionally a pair of linear layers (`encode_prefix` / `decode_prefix`)
    that map the incoming prefix embeddings to a hidden size and back to the GPT-2 embedding size.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the MLP if we encode the prefix. Must be set when `prefix_inner_dim` differs from `n_embd`.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / layer_idx + 1`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(
        self,
        prefix_length: int,
        prefix_inner_dim: int,
        prefix_hidden_dim: Optional[int] = None,
        vocab_size: int = 50257,  # Start of GPT2 config args
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        super().__init__()

        self.prefix_length = prefix_length

        # Without the encode/decode projection the prefix is fed to GPT-2 unchanged, so its width must already
        # equal the GPT-2 embedding width.
        if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
            raise ValueError(
                f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_inner_dim} and"
                f" `n_embd`: {n_embd} are not equal."
            )

        self.prefix_inner_dim = prefix_inner_dim
        self.prefix_hidden_dim = prefix_hidden_dim

        self.encode_prefix = (
            nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
            if self.prefix_hidden_dim is not None
            else nn.Identity()
        )
        self.decode_prefix = (
            nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
        )

        gpt_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
                Text tokens to use for inference.
            prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`):
                Prefix embedding to prepend to the embedded tokens.
            attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
                Attention mask over the concatenated prefix and text embeddings; forwarded to the GPT-2 model.
            labels (`torch.Tensor`, *optional*):
                Labels to use for language modeling. Only whether this is `None` matters: when it is supplied, the
                labels actually passed to GPT-2 are rebuilt as `prefix_length` zero tokens followed by `input_ids`.

        Returns:
            The GPT-2 model output, or the tuple `(output, hidden)` when `prefix_hidden_dim` is set, where `hidden` is
            the prefix after `encode_prefix`.
        """
        embedding_text = self.transformer.transformer.wte(input_ids)
        hidden = self.encode_prefix(prefix_embeds)
        prefix_embeds = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)

        if labels is not None:
            # Pad the label sequence on the left so that it lines up with the prefix positions.
            dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
            labels = torch.cat((dummy_token, input_ids), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
        if self.prefix_hidden_dim is not None:
            return out, hidden
        else:
            return out

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return a `(batch_size, prefix_length)` int64 tensor of zeros on `device`."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def encode(self, prefix):
        """Apply `encode_prefix` (a linear layer, or identity when `prefix_hidden_dim` is `None`) to `prefix`."""
        return self.encode_prefix(prefix)

    @torch.no_grad()
    def generate_captions(self, features, eos_token_id, device):
        """
        Generate caption tokens given text embedding features. Beam search is run separately for every item of the
        batch and only the best-scoring beam of each item is kept.

        Args:
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            eos_token_id (`int`):
                The token ID of the EOS token for the text decoder model.
            device:
                Device to perform text generation on.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: The best token sequence of every batch item stacked along dim 0, and
            the stacked sequence lengths belonging to those sequences.
        """

        features = torch.split(features, 1, dim=0)
        generated_tokens = []
        generated_seq_lengths = []
        for feature in features:
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            output_tokens, seq_lengths = self.generate_beam(
                input_embeds=feature, device=device, eos_token_id=eos_token_id
            )
            generated_tokens.append(output_tokens[0])
            generated_seq_lengths.append(seq_lengths[0])
        # NOTE(review): `torch.stack` needs equally long sequences, while `generate_beam` stops early once every
        # beam has emitted EOS, so the per-item lengths can differ -- confirm whether padding is needed here.
        generated_tokens = torch.stack(generated_tokens)
        generated_seq_lengths = torch.stack(generated_seq_lengths)
        return generated_tokens, generated_seq_lengths

    @torch.no_grad()
    def generate_beam(
        self,
        input_ids=None,
        input_embeds=None,
        device=None,
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """
        Generates text using the given tokenizer and text prompt or token embedding via beam search. This
        implementation is based on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        Args:
            eos_token_id (`int`, *optional*):
                The token ID of the EOS token for the text decoder model.
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                must be supplied.
            input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                `input_ids` and `input_embeds` must be supplied. Takes precedence over `input_ids` when both are
                given.
            device:
                The device to perform beam search on.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The temperature to use when performing the softmax over logits from the decoding model. Values that
                are not positive are treated as `1.0`.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
            token sequences sorted by score in descending order, and the second element is the sequence lengths
            corresponding to those sequences.

        Raises:
            `ValueError`: If neither `input_ids` nor `input_embeds` is supplied.
        """
        if input_embeds is None and input_ids is None:
            raise ValueError("One of `input_ids` and `input_embeds` must be supplied.")

        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = eos_token_id
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)

        if input_embeds is not None:
            generated = input_embeds
        else:
            generated = self.transformer.transformer.wte(input_ids)

        for _ in range(entry_length):
            outputs = self.transformer(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()

            if scores is None:
                # First step: seed the beams with the `beam_size` most likely first tokens.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                # A finished beam may only be extended with token 0 at zero cost, which keeps its score fixed.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank candidates by length-normalised score over the flattened (beam, vocab) grid.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # Recover the originating beam and the vocabulary index from the flat index.
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            # Adding boolean tensors acts as a logical OR: once stopped, a beam stays stopped.
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break

        scores = scores / seq_lengths
        order = scores.argsort(descending=True)
        # tokens tensors are already padded to max_seq_length
        output_texts = [tokens[i] for i in order]
        output_texts = torch.stack(output_texts, dim=0)
        seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype)
        return output_texts, seq_lengths
def trunc_normal_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fill `tensor` in place with samples from a truncated normal distribution and return it.

    The samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to :math:`[a, b]`. All work is
    delegated to `_no_grad_trunc_normal_`, which logs a warning when `mean` lies more than two standard deviations
    outside :math:`[a, b]`; sampling works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, modified in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same `tensor` object that was passed in.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w, std=0.02)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
class SkipBlock(nn.Module):
    """Merge a skip-connection tensor into the current hidden states.

    The two inputs are concatenated along the last dimension, projected from `2 * dim` back to `dim` with a linear
    layer, and then layer-normalized.

    Parameters:
        dim (`int`): Size of the last dimension of both inputs and of the output.
    """

    def __init__(self, dim: int):
        super().__init__()

        # Projects the concatenated [x, skip] features back down to `dim`.
        self.skip_linear = nn.Linear(2 * dim, dim)

        # Use torch.nn.LayerNorm for now, following the original code
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, skip):
        """Return `norm(skip_linear(concat(x, skip)))`; `x` and `skip` are joined on their last dimension."""
        merged = torch.cat((x, skip), dim=-1)
        projected = self.skip_linear(merged)
        return self.norm(projected)
double_self_attention (`bool`, *optional*): Whether to use two self-attention layers. In this case no cross attention layers are used. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float32 when performing the attention calculation. norm_elementwise_affine (`bool`, *optional*): Whether to use learnable per-element affine parameters during layer normalization. norm_type (`str`, defaults to `"layer_norm"`): The layer norm implementation to use. pre_layer_norm (`bool`, *optional*): Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g. `pre_layer_norm = True`. final_dropout (`bool`, *optional*): Whether to use a final Dropout layer after the feedforward network. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", pre_layer_norm: bool = True, final_dropout: bool = False, ): super().__init__() self.only_cross_attention = only_cross_attention self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.pre_layer_norm = pre_layer_norm if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." ) # 1. 
Self-Attn self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, ) # 2. Cross-Attn if cross_attention_dim is not None or double_self_attention: self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, ) # is self-attn if encoder_hidden_states is none else: self.attn2 = None if self.use_ada_layer_norm: self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) if cross_attention_dim is not None or double_self_attention: # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. self.norm2 = ( AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) ) else: self.norm2 = None # 3. Feed-forward self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) def forward( self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, timestep=None, cross_attention_kwargs=None, class_labels=None, ): # Pre-LayerNorm if self.pre_layer_norm: if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) else: norm_hidden_states = self.norm1(hidden_states) else: norm_hidden_states = hidden_states # 1. 
# Like UTransformerBlock except with LayerNorms on the residual backbone of the block
# Modified from diffusers.models.attention.BasicTransformerBlock
class UniDiffuserBlock(nn.Module):
    r"""
    Transformer block (self-attention, optional second attention, feed-forward) whose normalization layers act on
    the residual stream itself rather than only on the input of each sub-layer. It can run in a pre-LayerNorm or a
    post-LayerNorm arrangement.

    This matches the transformer block in the [original UniDiffuser
    implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).

    Parameters:
        dim (`int`): Number of channels of the input and the output.
        num_attention_heads (`int`): Number of heads in each attention layer.
        attention_head_dim (`int`): Number of channels per attention head.
        dropout (`float`, *optional*, defaults to 0.0): Dropout probability passed to attention and feed-forward.
        cross_attention_dim (`int`, *optional*): Size of the `encoder_hidden_states` vectors for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function of the feed-forward layer.
        num_embeds_ada_norm (`int`, *optional*):
            Number of diffusion steps used during training; required when `norm_type` is `"ada_norm"` or
            `"ada_norm_zero"`. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether the attention layers are built with a bias parameter.
        only_cross_attention (`bool`, *optional*):
            When set, the first attention layer also attends to `encoder_hidden_states`.
        double_self_attention (`bool`, *optional*):
            When set, the second attention layer is built without a cross-attention dimension.
        upcast_attention (`bool`, *optional*): Forwarded to the attention layers.
        norm_elementwise_affine (`bool`, *optional*):
            Whether the `nn.LayerNorm` layers have learnable per-element affine parameters.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization variant; `AdaLayerNorm` is used only for `"ada_norm"` with `num_embeds_ada_norm` set.
        pre_layer_norm (`bool`, *optional*, defaults to `False`):
            Normalize before each sub-layer ("pre-LayerNorm") instead of after it ("post-LayerNorm"). The original
            UniDiffuser implementation is post-LayerNorm (`pre_layer_norm = False`).
        final_dropout (`bool`, *optional*, defaults to `True`):
            Whether the feed-forward network ends with a Dropout layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        pre_layer_norm: bool = False,
        final_dropout: bool = True,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        self.pre_layer_norm = pre_layer_norm

        adaptive_norm_requested = norm_type in ("ada_norm", "ada_norm_zero")
        if adaptive_norm_requested and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # The second attention layer (and its norm) exists only in these configurations.
        has_second_attention = cross_attention_dim is not None or double_self_attention

        # NOTE: submodules are created in the order attn1, attn2, norm1, norm2, norm3, ff so that parameter
        # registration and initialization order stay stable.

        # 1. Self-Attn
        first_attn_context_dim = cross_attention_dim if only_cross_attention else None
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=first_attn_context_dim,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if has_second_attention:
            second_attn_context_dim = None if double_self_attention else cross_attention_dim
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=second_attn_context_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.attn2 = None

        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        if not has_second_attention:
            self.norm2 = None
        elif self.use_ada_layer_norm:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
        else:
            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

    def _apply_attn_norm(self, norm_layer, hidden_states, timestep):
        """Run `norm_layer`, passing `timestep` along only when the block uses `AdaLayerNorm`."""
        if self.use_ada_layer_norm:
            return norm_layer(hidden_states, timestep)
        return norm_layer(hidden_states)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        cross_attention_kwargs=None,
        class_labels=None,
    ):
        """
        Run self-attention, the optional second attention and the feed-forward network, each followed by a residual
        addition. Every normalization is applied to `hidden_states` itself (the residual backbone): before the
        sub-layer when `pre_layer_norm` is set, otherwise after the residual addition.

        `class_labels` is accepted but not used by this block.
        """
        pre_ln = self.pre_layer_norm
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        # 1. Self-Attention
        if pre_ln:
            hidden_states = self._apply_attn_norm(self.norm1, hidden_states, timestep)

        first_attn_context = encoder_hidden_states if self.only_cross_attention else None
        attn_output = self.attn1(
            hidden_states,
            encoder_hidden_states=first_attn_context,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + hidden_states

        if not pre_ln:
            hidden_states = self._apply_attn_norm(self.norm1, hidden_states, timestep)

        # 2. Cross-Attention
        if self.attn2 is not None:
            if pre_ln:
                hidden_states = self._apply_attn_norm(self.norm2, hidden_states, timestep)

            # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
            # prepare attention mask here
            attn_output = self.attn2(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

            if not pre_ln:
                hidden_states = self._apply_attn_norm(self.norm2, hidden_states, timestep)

        # 3. Feed-forward (norm3 is always a plain LayerNorm, so it never receives the timestep)
        if pre_ln:
            hidden_states = self.norm3(hidden_states)

        ff_output = self.ff(hidden_states)
        hidden_states = ff_output + hidden_states

        if not pre_ln:
            hidden_states = self.norm3(hidden_states)

        return hidden_states
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups to use when performing Group Normalization. cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. attention_bias (`bool`, *optional*): Configure if the TransformerBlocks' attention should contain a bias parameter. sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. Note that this is fixed at training time as it is used for learning a number of position embeddings. See `ImagePositionalEmbeddings`. num_vector_embeds (`int`, *optional*): Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked latent pixel. patch_size (`int`, *optional*, defaults to 2): The patch size to use in the patch embedding. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. The number of diffusion steps used during training. Note that this is fixed at training time as it is used to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for up to but not more than steps than `num_embeds_ada_norm`. use_linear_projection (int, *optional*): TODO: Not used only_cross_attention (`bool`, *optional*): Whether to use only cross-attention layers. In this case two cross attention layers are used in each transformer block. upcast_attention (`bool`, *optional*): Whether to upcast the query and key to float() when performing the attention calculation. norm_type (`str`, *optional*, defaults to `"layer_norm"`): The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. block_type (`str`, *optional*, defaults to `"unidiffuser"`): The transformer block implementation to use. 
@register_to_config
def __init__(
    self,
    num_attention_heads: int = 16,
    attention_head_dim: int = 88,
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    num_layers: int = 1,
    dropout: float = 0.0,
    norm_num_groups: int = 32,
    cross_attention_dim: Optional[int] = None,
    attention_bias: bool = False,
    sample_size: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    patch_size: Optional[int] = 2,
    activation_fn: str = "geglu",
    num_embeds_ada_norm: Optional[int] = None,
    use_linear_projection: bool = False,
    only_cross_attention: bool = False,
    upcast_attention: bool = False,
    norm_type: str = "layer_norm",
    block_type: str = "unidiffuser",
    pre_layer_norm: bool = False,
    norm_elementwise_affine: bool = True,
    use_patch_pos_embed=False,
    ff_final_dropout: bool = False,
):
    """
    Build the patch embedding, the U-shaped stack of transformer blocks and the output LayerNorm.

    The stack consists of `num_layers // 2` "in" blocks, a single mid block and `num_layers // 2` "out" blocks.
    Each out block is paired with a `SkipBlock` that merges the matching in-block activation before the
    transformer block runs. `norm_num_groups` and `num_vector_embeds` are accepted (and registered in the
    config) but are not used when building the layers.
    """
    super().__init__()

    self.use_linear_projection = use_linear_projection
    self.num_attention_heads = num_attention_heads
    self.attention_head_dim = attention_head_dim
    inner_dim = num_attention_heads * attention_head_dim

    # 1. Input
    # Only patch input of shape (batch_size, num_channels, height, width) is supported for now.
    assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size."
    assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size"

    # 2. Input layers (the sample is assumed square: height == width == sample_size)
    self.height = sample_size
    self.width = sample_size
    self.patch_size = patch_size
    self.pos_embed = PatchEmbed(
        height=sample_size,
        width=sample_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=inner_dim,
        use_pos_embed=use_patch_pos_embed,
    )

    # 3. Transformer blocks
    # Quick hack to make the transformer block type configurable.
    block_cls = UniDiffuserBlock if block_type == "unidiffuser" else UTransformerBlock

    def make_block():
        # Every in/mid/out block is built with exactly the same hyper-parameters.
        return block_cls(
            inner_dim,
            num_attention_heads,
            attention_head_dim,
            dropout=dropout,
            cross_attention_dim=cross_attention_dim,
            activation_fn=activation_fn,
            num_embeds_ada_norm=num_embeds_ada_norm,
            attention_bias=attention_bias,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            norm_type=norm_type,
            pre_layer_norm=pre_layer_norm,
            norm_elementwise_affine=norm_elementwise_affine,
            final_dropout=ff_final_dropout,
        )

    half_depth = num_layers // 2

    # "Downsample" side of the U (no actual downsampling happens).
    self.transformer_in_blocks = nn.ModuleList([make_block() for _ in range(half_depth)])

    self.transformer_mid_block = make_block()

    # "Upsample" side of the U: each stage first merges a skip connection
    # (concatenation + Linear + LayerNorm, see SkipBlock) and then runs a transformer block.
    out_stages = []
    for _ in range(half_depth):
        # Build the skip module before the block so module construction order is unchanged.
        skip_module = SkipBlock(
            inner_dim,
        )
        out_stages.append(nn.ModuleDict({"skip": skip_module, "block": make_block()}))
    self.transformer_out_blocks = nn.ModuleList(out_stages)

    # 4. Output layers
    self.out_channels = in_channels if out_channels is None else out_channels

    # Following the UniDiffuser U-ViT implementation, the transformer output goes through
    # a LayerNorm layer with per-element affine params.
    self.norm_out = nn.LayerNorm(inner_dim)
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels conditioning. cross_attention_kwargs (*optional*): Keyword arguments to supply to the cross attention layers, if used. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the transformer blocks. unpatchify (`bool`, *optional*, defaults to `True`): Whether to unpatchify the transformer output. Returns: [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # 0. Check inputs if not unpatchify and return_dict: raise ValueError( f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" " rather than (batch_size, num_channels, height, width)." ) # 1. Input if not hidden_states_is_embedding: hidden_states = self.pos_embed(hidden_states) # 2. 
class UniDiffuserModel(ModelMixin, ConfigMixin):
    """
    Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a
    modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the
    CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details).

    Parameters:
        text_dim (`int`): The hidden dimension of the CLIP text model used to embed prompts.
        clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed images.
        num_text_tokens (`int`, *optional*, defaults to 77):
            The number of text tokens in the prompt embedding sequence. Used to size the learned positional
            embedding.
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of output channels; if `None`, defaults to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups to use when performing Group Normalization.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*):
            The patch size to use in the patch embedding. Must be provided (the default `None` is rejected).
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        use_linear_projection (int, *optional*): TODO: Not used
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
            transformer block.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
            behavior in `diffusers`.)
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        use_timestep_embedding (`bool`, *optional*, defaults to `False`):
            Whether to pass the sinusoidal timestep projections through a `TimestepEmbedding` layer. If `False`,
            the projections are used as-is (`nn.Identity`).
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        use_patch_pos_embed (`bool`, *optional*):
            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
        ff_final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
        use_data_type_embedding (`bool`, *optional*):
            Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1
            is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type`
            argument, which can either be `1` to use the weights trained on non-publically-available data or `0`
            otherwise. This argument is subsequently embedded by the data type embedding, if used.
    """

    @register_to_config
    def __init__(
        self,
        text_dim: int = 768,
        clip_img_dim: int = 512,
        num_text_tokens: int = 77,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        block_type: str = "unidiffuser",
        pre_layer_norm: bool = False,
        use_timestep_embedding=False,
        norm_elementwise_affine: bool = True,
        use_patch_pos_embed=False,
        ff_final_dropout: bool = True,
        use_data_type_embedding: bool = False,
    ):
        super().__init__()

        # 0. Handle dimensions
        self.inner_dim = num_attention_heads * attention_head_dim

        assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size"
        # Fail early with a clear message instead of an opaque TypeError in the arithmetic below.
        assert (
            in_channels is not None and patch_size is not None
        ), "UniDiffuserModel over patched input must provide in_channels and patch_size"
        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels

        self.patch_size = patch_size
        # Assume image is square...
        self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size)

        # 1. Define input layers
        # 1.1 Input layers for text and image input
        # For now, only support patch input for VAE latent image input
        self.vae_img_in = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            use_pos_embed=use_patch_pos_embed,
        )
        self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim)
        self.text_in = nn.Linear(text_dim, self.inner_dim)

        # 1.2. Timestep embeddings for t_img, t_text
        self.timestep_img_proj = Timesteps(
            self.inner_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.timestep_img_embed = (
            TimestepEmbedding(
                self.inner_dim,
                4 * self.inner_dim,
                out_dim=self.inner_dim,
            )
            if use_timestep_embedding
            else nn.Identity()
        )

        self.timestep_text_proj = Timesteps(
            self.inner_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.timestep_text_embed = (
            TimestepEmbedding(
                self.inner_dim,
                4 * self.inner_dim,
                out_dim=self.inner_dim,
            )
            if use_timestep_embedding
            else nn.Identity()
        )

        # 1.3. Positional embedding
        # Token layout: [t_img, t_text, text tokens..., clip image token, vae patch tokens...]
        self.num_text_tokens = num_text_tokens
        self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim))
        self.pos_embed_drop = nn.Dropout(p=dropout)
        trunc_normal_(self.pos_embed, std=0.02)

        # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary
        self.use_data_type_embedding = use_data_type_embedding
        if self.use_data_type_embedding:
            self.data_type_token_embedding = nn.Embedding(2, self.inner_dim)
            self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim))

        # 2. Define transformer blocks
        self.transformer = UTransformer2DModel(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            dropout=dropout,
            norm_num_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            attention_bias=attention_bias,
            sample_size=sample_size,
            num_vector_embeds=num_vector_embeds,
            patch_size=patch_size,
            activation_fn=activation_fn,
            num_embeds_ada_norm=num_embeds_ada_norm,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            norm_type=norm_type,
            block_type=block_type,
            pre_layer_norm=pre_layer_norm,
            norm_elementwise_affine=norm_elementwise_affine,
            use_patch_pos_embed=use_patch_pos_embed,
            ff_final_dropout=ff_final_dropout,
        )

        # 3. Define output layers
        # Use the resolved `self.out_channels` so that the documented default `out_channels=None`
        # (fall back to `in_channels`) works here too; `forward` unpatchifies with `self.out_channels`.
        patch_dim = (patch_size**2) * self.out_channels
        self.vae_img_out = nn.Linear(self.inner_dim, patch_dim)
        self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim)
        self.text_out = nn.Linear(self.inner_dim, text_dim)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        return {"pos_embed"}

    def _timestep_token(self, timestep, proj, embed, batch_size, device):
        """Encode a scalar or tensor timestep into a single token of shape `(batch_size, 1, inner_dim)`."""
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], dtype=torch.long, device=device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timestep = timestep * torch.ones(batch_size, dtype=timestep.dtype, device=timestep.device)

        token = proj(timestep)
        # The sinusoidal projection does not contain any weights and will always return f32 tensors,
        # but the time embedding might be fp16, so we need to cast here.
        token = token.to(dtype=self.dtype)
        token = embed(token)
        return token.unsqueeze(dim=1)

    def forward(
        self,
        latent_image_embeds: torch.FloatTensor,
        image_embeds: torch.FloatTensor,
        prompt_embeds: torch.FloatTensor,
        timestep_img: Union[torch.Tensor, float, int],
        timestep_text: Union[torch.Tensor, float, int],
        data_type: Optional[Union[torch.Tensor, float, int]] = 1,
        encoder_hidden_states=None,
        cross_attention_kwargs=None,
    ):
        """
        Args:
            latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`):
                Latent image representation from the VAE encoder.
            image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`):
                CLIP-embedded image representation (unsqueezed in the first dimension).
            prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`):
                CLIP-embedded text representation.
            timestep_img (`torch.long` or `float` or `int`):
                Current denoising step for the image.
            timestep_text (`torch.long` or `float` or `int`):
                Current denoising step for the text.
            data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
                Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
                or `0` otherwise.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            cross_attention_kwargs (*optional*):
                Keyword arguments to supply to the cross attention layers, if used.


        Returns:
            `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is the VAE
            image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
            embedding.
        """
        batch_size = latent_image_embeds.shape[0]

        # 1. Input
        # 1.1. Map inputs to shape (B, N, inner_dim)
        vae_hidden_states = self.vae_img_in(latent_image_embeds)
        clip_hidden_states = self.clip_img_in(image_embeds)
        text_hidden_states = self.text_in(prompt_embeds)

        num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1)
        device = vae_hidden_states.device

        # 1.2. / 1.3. Encode image and text timesteps to single tokens (B, 1, inner_dim)
        timestep_img_token = self._timestep_token(
            timestep_img, self.timestep_img_proj, self.timestep_img_embed, batch_size, device
        )
        timestep_text_token = self._timestep_token(
            timestep_text, self.timestep_text_proj, self.timestep_text_embed, batch_size, device
        )

        # 1.4. Concatenate all of the embeddings together.
        tokens = [timestep_img_token, timestep_text_token]
        if self.use_data_type_embedding:
            assert data_type is not None, "data_type must be supplied if the model uses a data type embedding"
            if not torch.is_tensor(data_type):
                data_type = torch.tensor([data_type], dtype=torch.int, device=device)

            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device)
            tokens.append(self.data_type_token_embedding(data_type).unsqueeze(dim=1))
        tokens.extend([text_hidden_states, clip_hidden_states, vae_hidden_states])
        hidden_states = torch.cat(tokens, dim=1)

        # 1.5. Prepare the positional embeddings and add to hidden states
        # Note: I think img_vae should always have the proper shape, so there's no need to interpolate
        # the position embeddings.
        if self.use_data_type_embedding:
            # Insert the data type position token right after the two timestep positions.
            pos_embed = torch.cat(
                [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1
            )
        else:
            pos_embed = self.pos_embed
        hidden_states = hidden_states + pos_embed
        hidden_states = self.pos_embed_drop(hidden_states)

        # 2. Blocks
        hidden_states = self.transformer(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            timestep=None,
            class_labels=None,
            cross_attention_kwargs=cross_attention_kwargs,
            return_dict=False,
            hidden_states_is_embedding=True,
            unpatchify=False,
        )[0]

        # 3. Output
        # Split out the predicted noise representation; the timestep / data type tokens are discarded.
        leading_tokens = 3 if self.use_data_type_embedding else 2
        split_sizes = (1,) * leading_tokens + (num_text_tokens, 1, num_img_tokens)
        *_, text_out, img_clip_out, img_vae_out = hidden_states.split(split_sizes, dim=1)

        img_vae_out = self.vae_img_out(img_vae_out)

        # unpatchify (assumes a square grid of patches)
        height = width = int(img_vae_out.shape[1] ** 0.5)
        img_vae_out = img_vae_out.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out)
        img_vae_out = img_vae_out.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        )

        img_clip_out = self.clip_img_out(img_clip_out)

        text_out = self.text_out(text_out)

        return img_vae_out, img_clip_out, text_out
# New BaseOutput child class for joint image-text output
@dataclass
class ImageTextPipelineOutput(BaseOutput):
    """
    Output class for joint image-text pipelines.

    Both fields are annotated as `Optional`, so either may be `None`.
    NOTE(review): presumably a field is `None` when the selected generation mode does not produce that modality --
    confirm against the pipeline's `__call__`.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        text (`List[str]` or `List[List[str]]`)
            List of generated text strings of length `batch_size` or a list of list of strings whose outer list has
            length `batch_size`.
    """

    # Generated images, as PIL images or as a single NumPy array.
    images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
    # Generated text, as a flat list of strings or one list of strings per batch element.
    text: Optional[Union[List[str], List[List[str]]]]
This is part of the UniDiffuser image representation, along with the CLIP vision encoding. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text prompts. image_encoder ([`CLIPVisionModel`]): UniDiffuser uses the vision portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode images as part of its image representation, along with the VAE latent representation. image_processor ([`CLIPImageProcessor`]): CLIP image processor of class [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor), used to preprocess the image before CLIP encoding it with `image_encoder`. clip_tokenizer ([`CLIPTokenizer`]): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which is used to tokenizer a prompt before encoding it with `text_encoder`. text_decoder ([`UniDiffuserTextDecoder`]): Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser embedding. text_tokenizer ([`GPT2Tokenizer`]): Tokenizer of class [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which is used along with the `text_decoder` to decode text for text generation. unet ([`UniDiffuserModel`]): UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a [`Transformer2DModel`] with U-Net-style skip connections between transformer layers. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. 
""" def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, image_encoder: CLIPVisionModelWithProjection, image_processor: CLIPImageProcessor, clip_tokenizer: CLIPTokenizer, text_decoder: UniDiffuserTextDecoder, text_tokenizer: GPT2Tokenizer, unet: UniDiffuserModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: raise ValueError( f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" ) self.register_modules( vae=vae, text_encoder=text_encoder, image_encoder=image_encoder, image_processor=image_processor, clip_tokenizer=clip_tokenizer, text_decoder=text_decoder, text_tokenizer=text_tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.num_channels_latents = vae.config.latent_channels self.text_encoder_seq_len = text_encoder.config.max_position_embeddings self.text_encoder_hidden_size = text_encoder.config.hidden_size self.image_encoder_projection_dim = image_encoder.config.projection_dim self.unet_resolution = unet.config.sample_size self.text_intermediate_dim = self.text_encoder_hidden_size if self.text_decoder.prefix_hidden_dim is not None: self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim self.mode = None # TODO: handle safety checking? self.safety_checker = None # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
# Add self.image_encoder, self.text_decoder to cpu_offloaded_models list
def enable_model_cpu_offload(self, gpu_id=0):
    r"""
    Offload whole models to the CPU with accelerate, trading a small slowdown for lower GPU memory usage.

    Unlike `enable_sequential_cpu_offload`, each model is moved to the GPU as a unit when its `forward` method is
    called and stays there until the next model runs. This saves less memory than sequential offloading, but is
    much faster because of the iterative execution of the `unet`.

    Args:
        gpu_id (`int`, *optional*, defaults to 0): Index of the CUDA device the models execute on.

    Raises:
        ImportError: If `accelerate` is not installed or is older than `0.17.0.dev0`.
    """
    accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
    if not accelerate_ok:
        raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
    from accelerate import cpu_offload_with_hook

    target_device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # Chain the hooks: each model's hook receives the previous model's hook as `prev_module_hook`.
    last_hook = None
    offload_order = (self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder)
    for model in offload_order:
        _, last_hook = cpu_offload_with_hook(model, target_device, prev_module_hook=last_hook)

    if self.safety_checker is not None:
        _, last_hook = cpu_offload_with_hook(self.safety_checker, target_device, prev_module_hook=last_hook)

    # We'll offload the last model manually.
    self.final_offload_hook = last_hook
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): r""" Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set mode will be used. """ prompt_available = (prompt is not None) or (prompt_embeds is not None) image_available = image is not None input_available = prompt_available or image_available prompt_latents_available = prompt_latents is not None vae_latents_available = vae_latents is not None clip_latents_available = clip_latents is not None full_latents_available = latents is not None image_latents_available = vae_latents_available and clip_latents_available all_indv_latents_available = prompt_latents_available and image_latents_available if self.mode is not None: # Preferentially use the mode set by the user mode = self.mode elif prompt_available: mode = "text2img" elif image_available: mode = "img2text" else: # Neither prompt nor image supplied, infer based on availability of latents if full_latents_available or all_indv_latents_available: mode = "joint" elif prompt_latents_available: mode = "text" elif image_latents_available: mode = "img" else: # No inputs or latents available mode = "joint" # Give warnings for ambiguous cases if self.mode is None and prompt_available and image_available: logger.warning( f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," f" defaulting to mode '{mode}'." 
) if self.mode is None and not input_available: if vae_latents_available != clip_latents_available: # Exactly one of vae_latents and clip_latents is supplied logger.warning( f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" f" are expected to be supplied. Defaulting to mode '{mode}'." ) elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: # No inputs or latents supplied logger.warning( f"No inputs or latents have been supplied, and mode has not been manually set," f" defaulting to mode '{mode}'." ) return mode # Functions to manually set the mode def set_text_mode(self): r"""Manually set the generation mode to unconditional ("marginal") text generation.""" self.mode = "text" def set_image_mode(self): r"""Manually set the generation mode to unconditional ("marginal") image generation.""" self.mode = "img" def set_text_to_image_mode(self): r"""Manually set the generation mode to text-conditioned image generation.""" self.mode = "text2img" def set_image_to_text_mode(self): r"""Manually set the generation mode to image-conditioned text generation.""" self.mode = "img2text" def set_joint_mode(self): r"""Manually set the generation mode to unconditional joint image-text generation.""" self.mode = "joint" def reset_mode(self): r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" self.mode = None def _infer_batch_size( self, mode, prompt, prompt_embeds, image, num_images_per_prompt, num_prompts_per_image, latents, prompt_latents, vae_latents, clip_latents, ): r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" if num_images_per_prompt is None: num_images_per_prompt = 1 if num_prompts_per_image is None: num_prompts_per_image = 1 assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive 
def _infer_batch_size(
    self,
    mode,
    prompt,
    prompt_embeds,
    image,
    num_images_per_prompt,
    num_prompts_per_image,
    latents,
    prompt_latents,
    vae_latents,
    clip_latents,
):
    r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.

    The batch size is read from the input that drives the given `mode` (prompt, image, or the first available
    latents) and falls back to 1. The multiplier is `num_images_per_prompt` for image-producing modes,
    `num_prompts_per_image` for text-producing modes, and the smaller of the two in `joint` mode.

    Returns:
        `Tuple[int, int]`: `(batch_size, multiplier)`.
    """
    if num_images_per_prompt is None:
        num_images_per_prompt = 1
    if num_prompts_per_image is None:
        num_prompts_per_image = 1

    # Kept as assertions so that callers relying on AssertionError keep working.
    assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer"
    assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer"

    if mode == "text2img":
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # Either prompt or prompt_embeds must be present for text2img.
            batch_size = prompt_embeds.shape[0]
        multiplier = num_images_per_prompt
    elif mode == "img2text":
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        else:
            # Image must be available and type either PIL.Image.Image or torch.FloatTensor.
            # Not currently supporting something like image_embeds.
            batch_size = image.shape[0]
        multiplier = num_prompts_per_image
    elif mode == "img":
        if vae_latents is not None:
            batch_size = vae_latents.shape[0]
        elif clip_latents is not None:
            batch_size = clip_latents.shape[0]
        else:
            batch_size = 1
        multiplier = num_images_per_prompt
    elif mode == "text":
        if prompt_latents is not None:
            batch_size = prompt_latents.shape[0]
        else:
            batch_size = 1
        multiplier = num_prompts_per_image
    elif mode == "joint":
        if latents is not None:
            batch_size = latents.shape[0]
        elif prompt_latents is not None:
            batch_size = prompt_latents.shape[0]
        elif vae_latents is not None:
            batch_size = vae_latents.shape[0]
        elif clip_latents is not None:
            batch_size = clip_latents.shape[0]
        else:
            batch_size = 1

        if num_images_per_prompt == num_prompts_per_image:
            multiplier = num_images_per_prompt
        else:
            multiplier = min(num_images_per_prompt, num_prompts_per_image)
            # Report the value that is actually used (the multiplier), not the batch size.
            logger.warning(
                f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and"
                f" `num_prompts_per_image`: {num_prompts_per_image} are not equal. Using a multiplier equal to"
                f" `min(num_images_per_prompt, num_prompts_per_image)` = {multiplier}."
            )
    return batch_size, multiplier
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
# self.tokenizer => self.clip_tokenizer
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.

    Returns:
        `torch.FloatTensor`: the prompt embeddings repeated `num_images_per_prompt` times; with classifier-free
        guidance the negative embeddings are concatenated in front of them along the batch dimension.

    Raises:
        TypeError: if `prompt` and `negative_prompt` are both given but have different types.
        ValueError: if `negative_prompt` is a list whose length differs from the prompt batch size.
    """
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        text_inputs = self.clip_tokenizer(
            prompt,
            padding="max_length",
            max_length=self.clip_tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        # Tokenize again without truncation only to be able to tell the user what was cut off.
        untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.clip_tokenizer.batch_decode(
                untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            # Only comparable when a raw prompt was given; with `prompt_embeds` the prompt is None and any
            # `negative_prompt` type is acceptable.
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # Pad the negative prompt to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.clip_tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents
# Add num_prompts_per_image argument, sample from autoencoder moment distribution
def encode_image_vae_latents(
    self,
    image,
    batch_size,
    num_prompts_per_image,
    dtype,
    device,
    do_classifier_free_guidance,
    generator=None,
):
    """
    Encode `image` with the VAE and sample scaled latents from the resulting distribution.

    The effective batch size is `batch_size * num_prompts_per_image`. When fewer latents than that are
    produced and the effective batch size is a multiple of their count, the latents are tiled to match
    (with a deprecation notice). With `do_classifier_free_guidance`, the result is the latents twice
    followed by an all-zero copy, concatenated along the batch dimension.

    Raises:
        ValueError: for an unsupported `image` type, a generator list whose length differs from the
            effective batch size, or latents that cannot be tiled evenly to the effective batch size.
    """
    accepted_types = (torch.Tensor, PIL.Image.Image, list)
    if not isinstance(image, accepted_types):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    batch_size = batch_size * num_prompts_per_image
    one_generator_per_sample = isinstance(generator, list)
    if one_generator_per_sample and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if one_generator_per_sample:
        # Encode sample by sample so that each one draws from its own generator.
        per_sample = []
        for idx in range(batch_size):
            posterior = self.vae.encode(image[idx : idx + 1]).latent_dist
            per_sample.append(posterior.sample(generator=generator[idx]) * self.vae.config.scaling_factor)
        image_latents = torch.cat(per_sample, dim=0)
    else:
        image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
        # Scale image_latents by the VAE's scaling factor
        image_latents = image_latents * self.vae.config.scaling_factor

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if do_classifier_free_guidance:
        zero_latents = torch.zeros_like(image_latents)
        image_latents = torch.cat([image_latents, image_latents, zero_latents], dim=0)

    return image_latents
def encode_image_clip_latents(
    self,
    image,
    batch_size,
    num_prompts_per_image,
    dtype,
    device,
    generator=None,
):
    """
    Map `image` to its CLIP image embedding.

    The image is preprocessed with `self.image_processor`, moved to `device`/`dtype` and passed through
    `self.image_encoder`; the `image_embeds` output is returned. The effective batch size is
    `batch_size * num_prompts_per_image`, and embeddings are tiled to it when it is an exact multiple
    (with a deprecation notice).

    Raises:
        ValueError: for an unsupported `image` type, embeddings that cannot be tiled evenly, or (checked
            last) a generator list whose length differs from the effective batch size.
    """
    # Map image to CLIP embedding.
    accepted_types = (torch.Tensor, PIL.Image.Image, list)
    if not isinstance(image, accepted_types):
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    preprocessed_image = self.image_processor.preprocess(
        image,
        return_tensors="pt",
    ).to(device=device, dtype=dtype)

    batch_size = batch_size * num_prompts_per_image
    one_generator_per_sample = isinstance(generator, list)

    if one_generator_per_sample:
        per_sample = []
        for idx in range(batch_size):
            encoded = self.image_encoder(**preprocessed_image[idx : idx + 1])
            per_sample.append(encoded.image_embeds)
        image_latents = torch.cat(per_sample, dim=0)
    else:
        image_latents = self.image_encoder(**preprocessed_image).image_embeds

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if one_generator_per_sample and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    return image_latents


# Note that the CLIP latents are not decoded for image generation.
# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
# Rename: decode_latents -> decode_image_latents
def decode_image_latents(self, latents):
    """
    Decode VAE latents into images.

    The latents are divided by the VAE scaling factor, decoded, mapped through `x / 2 + 0.5`, clamped to
    [0, 1], and returned as a float32 NumPy array with the channel axis moved last.
    """
    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
def prepare_text_latents(
    self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None
):
    """
    Prepare latents for the CLIP embedded prompt.

    When `latents` is None, noise of shape `(batch_size * num_images_per_prompt, seq_len, hidden_size)` is
    sampled; otherwise the supplied latents are repeated `num_images_per_prompt` times along the batch
    dimension. The result is multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from the effective batch size
            `batch_size * num_images_per_prompt`.
    """
    # A list of generators supplies one generator per sampled row, so it has to match the number of rows
    # actually sampled, not the un-multiplied batch size.
    effective_batch_size = batch_size * num_images_per_prompt
    shape = (effective_batch_size, seq_len, hidden_size)
    if isinstance(generator, list) and len(generator) != effective_batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, L, D)
        latents = latents.repeat(num_images_per_prompt, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents


# Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
# Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument.
def prepare_image_vae_latents(
    self,
    batch_size,
    num_prompts_per_image,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """
    Prepare latents for the VAE-encoded image.

    When `latents` is None, noise of shape `(batch_size * num_prompts_per_image, num_channels_latents,
    height // vae_scale_factor, width // vae_scale_factor)` is sampled; otherwise the supplied latents are
    repeated `num_prompts_per_image` times along the batch dimension. The result is multiplied by
    `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from the effective batch size
            `batch_size * num_prompts_per_image`.
    """
    effective_batch_size = batch_size * num_prompts_per_image
    shape = (
        effective_batch_size,
        num_channels_latents,
        height // self.vae_scale_factor,
        width // self.vae_scale_factor,
    )
    if isinstance(generator, list) and len(generator) != effective_batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, C, H, W)
        latents = latents.repeat(num_prompts_per_image, 1, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents
def prepare_image_clip_latents(
    self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None
):
    """
    Prepare latents for the CLIP embedded image.

    When `latents` is None, noise of shape `(batch_size * num_prompts_per_image, 1, clip_img_dim)` is
    sampled; otherwise the supplied latents are repeated `num_prompts_per_image` times along the batch
    dimension. The result is multiplied by `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: if `generator` is a list whose length differs from the effective batch size
            `batch_size * num_prompts_per_image`.
    """
    # One generator per sampled row: compare against the number of rows actually sampled.
    effective_batch_size = batch_size * num_prompts_per_image
    shape = (effective_batch_size, 1, clip_img_dim)
    if isinstance(generator, list) and len(generator) != effective_batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {effective_batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # latents is assumed to have shape (B, L, D)
        latents = latents.repeat(num_prompts_per_image, 1, 1)
        latents = latents.to(device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents


def _split(self, x, height, width):
    r"""
    Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W)
    and (B, 1, clip_img_dim)
    """
    batch_size = x.shape[0]
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    img_vae_dim = self.num_channels_latents * latent_height * latent_width

    img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1)

    img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
    img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
    return img_vae, img_clip


def _combine(self, img_vae, img_clip):
    r"""
    Combines a latent image img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
    clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
    """
    img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
    img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
    return torch.concat([img_vae, img_clip], dim=-1)


def _split_joint(self, x, height, width):
    r"""
    Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim) into (img_vae,
    img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is
    of shape (B, text_seq_len, text_dim).
    """
    batch_size = x.shape[0]
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    img_vae_dim = self.num_channels_latents * latent_height * latent_width
    text_dim = self.text_encoder_seq_len * self.text_intermediate_dim

    img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1)

    img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width))
    img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim))
    text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim))
    return img_vae, img_clip, text


def _combine_joint(self, img_vae, img_clip, text):
    r"""
    Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img,
    clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B,
    C * H * W + L_img * clip_img_dim + L_text * text_dim).
    """
    img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
    img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1))
    text = torch.reshape(text, (text.shape[0], -1))
    return torch.concat([img_vae, img_clip, text], dim=-1)
def _get_noise_pred(
    self,
    mode,
    latents,
    t,
    prompt_embeds,
    img_vae,
    img_clip,
    max_timestep,
    data_type,
    guidance_scale,
    generator,
    device,
    height,
    width,
):
    r"""
    Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary.

    What `latents` holds depends on `mode`: it is split with `_split_joint` in `joint` mode, split with `_split`
    in `text2img` and `img` modes, and passed to the `unet` as the text input in `img2text` and `text` modes.
    Guidance is applied only in the `joint`, `text2img` and `img2text` modes and only when
    `guidance_scale > 1.0`; the guided result is `guidance_scale * cond + (1.0 - guidance_scale) * uncond`.

    Returns:
        The combined prediction for the modality being generated (joint, image, or text). No branch handles
        a `mode` outside the five known values, so the function then falls through and returns `None`.
    """
    if mode == "joint":
        # Joint text-image generation
        img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width)

        # Image and text are denoised together, both at timestep t.
        img_vae_out, img_clip_out, text_out = self.unet(
            img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type
        )

        x_out = self._combine_joint(img_vae_out, img_clip_out, text_out)

        if guidance_scale <= 1.0:
            return x_out

        # Classifier-free guidance
        # Fresh noise stands in for the "other" modality; the draw order below is part of the behavior
        # because all three draws consume the same generator.
        img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
        img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
        text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)

        # Text prediction with the image replaced by noise at max_timestep.
        _, _, text_out_uncond = self.unet(
            img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
        )

        # Image prediction with the text replaced by noise at max_timestep.
        img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
            img_vae_latents,
            img_clip_latents,
            text_T,
            timestep_img=t,
            timestep_text=max_timestep,
            data_type=data_type,
        )

        x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)

        return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond
    elif mode == "text2img":
        # Text-conditioned image generation
        img_vae_latents, img_clip_latents = self._split(latents, height, width)

        # The conditioning text is passed with text timestep 0.
        img_vae_out, img_clip_out, text_out = self.unet(
            img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type
        )

        img_out = self._combine(img_vae_out, img_clip_out)

        if guidance_scale <= 1.0:
            return img_out

        # Classifier-free guidance
        text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)

        img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
            img_vae_latents,
            img_clip_latents,
            text_T,
            timestep_img=t,
            timestep_text=max_timestep,
            data_type=data_type,
        )

        img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)

        return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
    elif mode == "img2text":
        # Image-conditioned text generation
        # The conditioning image is passed with image timestep 0; `latents` is the text being denoised.
        img_vae_out, img_clip_out, text_out = self.unet(
            img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type
        )

        if guidance_scale <= 1.0:
            return text_out

        # Classifier-free guidance
        img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
        img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)

        img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
            img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
        )

        return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
    elif mode == "text":
        # Unconditional ("marginal") text generation (no CFG)
        img_vae_out, img_clip_out, text_out = self.unet(
            img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type
        )

        return text_out
    elif mode == "img":
        # Unconditional ("marginal") image generation (no CFG)
        img_vae_latents, img_clip_latents = self._split(latents, height, width)

        img_vae_out, img_clip_out, text_out = self.unet(
            img_vae_latents,
            img_clip_latents,
            prompt_embeds,
            timestep_img=t,
            timestep_text=max_timestep,
            data_type=data_type,
        )

        img_out = self._combine(img_vae_out, img_clip_out)
        return img_out
def check_latents_shape(self, latents_name, latents, expected_shape):
    """
    Validate that `latents.shape` equals `(batch_size, *expected_shape)` for any batch size.

    Args:
        latents_name: name used for the tensor in error messages.
        latents: object exposing a `shape` sequence.
        expected_shape: expected sizes of all dimensions after the batch dimension.

    Raises:
        ValueError: if the number of dimensions is wrong, or at the first non-batch dimension whose size
            differs from the expected one.
    """
    latents_shape = latents.shape
    expected_shape_str = ", ".join(map(str, expected_shape))

    # expected dimensions plus the batch dimension
    if len(latents_shape) != len(expected_shape) + 1:
        raise ValueError(
            f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
            f" {latents_shape} has {len(latents_shape)} dimensions."
        )

    # Dimension 0 is the batch dimension and is left unconstrained.
    for dim_index, expected_size in enumerate(expected_shape, start=1):
        actual_size = latents_shape[dim_index]
        if actual_size != expected_size:
            raise ValueError(
                f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape"
                f" {latents_shape} has {actual_size} != {expected_size} at dimension {dim_index}."
            )
def check_inputs(
    self,
    mode,
    prompt,
    image,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    latents=None,
    prompt_latents=None,
    vae_latents=None,
    clip_latents=None,
):
    """
    Check inputs before running the generative process.

    Validates, in order: divisibility of `height`/`width` by the VAE scale factor, `callback_steps`, the
    prompt-related arguments (only in `text2img` mode), the presence of `image` (only in `img2text` mode),
    the shapes of any supplied latents, and finally the agreement of the batch dimensions of individually
    supplied latents.

    Raises:
        ValueError: on the first failed check.
    """
    scale = self.vae_scale_factor
    if height % scale != 0 or width % scale != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
        )

    if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if mode == "text2img":
        has_prompt = prompt is not None
        has_prompt_embeds = prompt_embeds is not None
        if has_prompt and has_prompt_embeds:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        if not has_prompt and not has_prompt_embeds:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if has_prompt and not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if (
            has_prompt_embeds
            and negative_prompt_embeds is not None
            and prompt_embeds.shape != negative_prompt_embeds.shape
        ):
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if mode == "img2text" and image is None:
        raise ValueError("`img2text` mode requires an image to be provided.")

    # Check provided latents
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    has_prompt_latents = prompt_latents is not None
    has_vae_latents = vae_latents is not None
    has_clip_latents = clip_latents is not None

    if latents is not None:
        if has_prompt_latents or has_vae_latents or has_clip_latents:
            logger.warning(
                "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and"
                " `clip_latents`. The value of `latents` will override the value of any individually supplied latents."
            )
        # Check shape of full latents
        img_vae_dim = self.num_channels_latents * latent_height * latent_width
        text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size
        latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim
        self.check_latents_shape("latents", latents, (latents_dim,))

    # Check individual latent shapes, if present
    if has_prompt_latents:
        self.check_latents_shape(
            "prompt_latents", prompt_latents, (self.text_encoder_seq_len, self.text_encoder_hidden_size)
        )

    if has_vae_latents:
        self.check_latents_shape(
            "vae_latents", vae_latents, (self.num_channels_latents, latent_height, latent_width)
        )

    if has_clip_latents:
        self.check_latents_shape("clip_latents", clip_latents, (1, self.image_encoder_projection_dim))

    if mode in ["text2img", "img"] and has_vae_latents and has_clip_latents:
        if vae_latents.shape[0] != clip_latents.shape[0]:
            raise ValueError(
                f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:"
                f" {vae_latents.shape[0]} != {clip_latents.shape[0]}."
            )

    if mode == "joint" and has_prompt_latents and has_vae_latents and has_clip_latents:
        if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]:
            raise ValueError(
                f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch"
                f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}"
                f" != {clip_latents.shape[0]}."
            )
) @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, data_type: Optional[int] = 1, num_inference_steps: int = 50, guidance_scale: float = 8.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, num_prompts_per_image: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_latents: Optional[torch.FloatTensor] = None, vae_latents: Optional[torch.FloatTensor] = None, clip_latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. Required for text-conditioned image generation (`text2img`) mode. image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): `Image`, or tensor representing an image batch. Required for image-conditioned text generation (`img2text`) mode. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The width in pixels of the generated image. data_type (`int`, *optional*, defaults to 1): The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type embedding; this is added for compatibility with the UniDiffuser-v1 checkpoint. 
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 8.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. Note that the original [UniDiffuser paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`, which satisfies `w = w' + 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). Used in text-conditioned image generation (`text2img`) mode. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. num_prompts_per_image (`int`, *optional*, defaults to 1): The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint image-text generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note that this is assumed to be a full set of VAE, CLIP, and text latents, if supplied, this will override the value of `prompt_latents`, `vae_latents`, and `clip_latents`. prompt_latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. vae_latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. clip_latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
Used in text-conditioned image generation (`text2img`) mode. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. Used in text-conditioned image generation (`text2img`) mode. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the second element is a list of generated texts. """ # 0. Default height and width to unet height = height or self.unet_resolution * self.vae_scale_factor width = width or self.unet_resolution * self.vae_scale_factor # 1. Check inputs # Recalculate mode for each call to the pipeline. mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) self.check_inputs( mode, prompt, image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, latents, prompt_latents, vae_latents, clip_latents, ) # 2. 
Define call parameters batch_size, multiplier = self._infer_batch_size( mode, prompt, prompt_embeds, image, num_images_per_prompt, num_prompts_per_image, latents, prompt_latents, vae_latents, clip_latents, ) device = self._execution_device reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. # Note that this differs from the formulation in the unidiffusers paper! # do_classifier_free_guidance = guidance_scale > 1.0 # check if scheduler is in sigmas space # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") # 3. Encode input prompt, if available; otherwise prepare text latents if latents is not None: # Overwrite individual latents vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) if mode in ["text2img"]: # 3.1. Encode input prompt, if available assert prompt is not None or prompt_embeds is not None prompt_embeds = self._encode_prompt( prompt=prompt, device=device, num_images_per_prompt=multiplier, do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) else: # 3.2. Prepare text latent variables, if input not available prompt_embeds = self.prepare_text_latents( batch_size=batch_size, num_images_per_prompt=multiplier, seq_len=self.text_encoder_seq_len, hidden_size=self.text_encoder_hidden_size, dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision device=device, generator=generator, latents=prompt_latents, ) if reduce_text_emb_dim: prompt_embeds = self.text_decoder.encode(prompt_embeds) # 4. 
Encode image, if available; otherwise prepare image latents if mode in ["img2text"]: # 4.1. Encode images, if available assert image is not None, "`img2text` requires a conditioning image" # Encode image using VAE image_vae = preprocess(image) height, width = image_vae.shape[-2:] image_vae_latents = self.encode_image_vae_latents( image=image_vae, batch_size=batch_size, num_prompts_per_image=multiplier, dtype=prompt_embeds.dtype, device=device, do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG generator=generator, ) # Encode image using CLIP image_clip_latents = self.encode_image_clip_latents( image=image, batch_size=batch_size, num_prompts_per_image=multiplier, dtype=prompt_embeds.dtype, device=device, generator=generator, ) # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) image_clip_latents = image_clip_latents.unsqueeze(1) else: # 4.2. Prepare image latent variables, if input not available # Prepare image VAE latents in latent space image_vae_latents = self.prepare_image_vae_latents( batch_size=batch_size, num_prompts_per_image=multiplier, num_channels_latents=self.num_channels_latents, height=height, width=width, dtype=prompt_embeds.dtype, device=device, generator=generator, latents=vae_latents, ) # Prepare image CLIP latents image_clip_latents = self.prepare_image_clip_latents( batch_size=batch_size, num_prompts_per_image=multiplier, clip_img_dim=self.image_encoder_projection_dim, dtype=prompt_embeds.dtype, device=device, generator=generator, latents=clip_latents, ) # 5. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # max_timestep = timesteps[0] max_timestep = self.scheduler.config.num_train_timesteps # 6. 
Prepare latent variables if mode == "joint": latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) elif mode in ["text2img", "img"]: latents = self._combine(image_vae_latents, image_clip_latents) elif mode in ["img2text", "text"]: latents = prompt_embeds # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # predict the noise residual # Also applies classifier-free guidance as described in the UniDiffuser paper noise_pred = self._get_noise_pred( mode, latents, t, prompt_embeds, image_vae_latents, image_clip_latents, max_timestep, data_type, guidance_scale, generator, device, height, width, ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. 
Post-processing gen_image = None gen_text = None if mode == "joint": image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) # Map latent VAE image back to pixel space gen_image = self.decode_image_latents(image_vae_latents) # Generate text using the text decoder output_token_list, seq_lengths = self.text_decoder.generate_captions( text_latents, self.text_tokenizer.eos_token_id, device=device ) output_list = output_token_list.cpu().numpy() gen_text = [ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) for output, length in zip(output_list, seq_lengths) ] elif mode in ["text2img", "img"]: image_vae_latents, image_clip_latents = self._split(latents, height, width) gen_image = self.decode_image_latents(image_vae_latents) elif mode in ["img2text", "text"]: text_latents = latents output_token_list, seq_lengths = self.text_decoder.generate_captions( text_latents, self.text_tokenizer.eos_token_id, device=device ) output_list = output_token_list.cpu().numpy() gen_text = [ self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) for output, length in zip(output_list, seq_lengths) ] # 10. 
Convert to PIL if output_type == "pil" and gen_image is not None: gen_image = self.numpy_to_pil(gen_image) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (gen_image, gen_text) return ImageTextPipelineOutput(images=gen_image, text=gen_text) ================================================ FILE: 6DoF/diffusers/pipelines/versatile_diffusion/__init__.py ================================================ from ...utils import ( OptionalDependencyNotAvailable, is_torch_available, is_transformers_available, is_transformers_version, ) try: if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, VersatileDiffusionTextToImagePipeline, ) else: from .modeling_text_unet import UNetFlatConditionModel from .pipeline_versatile_diffusion import VersatileDiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline ================================================ FILE: 6DoF/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py ================================================ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation from ...models.attention import Attention from 
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    transformer_layers_per_block=1,
    attention_head_dim=None,
):
    """Build the "flat" down-sampling block named by `down_block_type`.

    A leading `"UNetRes"` prefix on `down_block_type` is stripped before the name is matched.

    Args:
        down_block_type: Block name; `"DownBlockFlat"` or `"CrossAttnDownBlockFlat"`.
        num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_eps, resnet_act_fn,
        resnet_groups, downsample_padding, resnet_time_scale_shift:
            Forwarded unchanged to the selected block constructor.
        num_attention_heads, cross_attention_dim, dual_cross_attention, use_linear_projection,
        only_cross_attention:
            Forwarded only to `CrossAttnDownBlockFlat`. `cross_attention_dim` is required for that block.
        upcast_attention, resnet_skip_time_act, resnet_out_scale_factor, cross_attention_norm:
            Accepted so callers can pass a uniform keyword set; not forwarded by this factory.
        transformer_layers_per_block, attention_head_dim:
            Accepted because `UNetFlatConditionModel.__init__` passes them by keyword (omitting them from the
            signature makes that call raise `TypeError`). Not forwarded by this factory.

    Returns:
        The constructed `DownBlockFlat` or `CrossAttnDownBlockFlat`.

    Raises:
        ValueError: If `cross_attention_dim` is `None` for `CrossAttnDownBlockFlat`, or if the block type is
            not one of the supported names.
    """
    # NOTE(review): `transformer_layers_per_block` / `attention_head_dim` are accepted but intentionally not
    # forwarded; confirm the Flat block constructors take them before wiring them through.
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlockFlat":
        return DownBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlockFlat":
        # Validate before touching the block class so the caller gets a clear configuration error.
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat")
        return CrossAttnDownBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} is not supported.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    num_attention_heads,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    cross_attention_norm=None,
    transformer_layers_per_block=1,
    attention_head_dim=None,
):
    """Build the "flat" up-sampling block named by `up_block_type`.

    A leading `"UNetRes"` prefix on `up_block_type` is stripped before the name is matched.

    Args:
        up_block_type: Block name; `"UpBlockFlat"` or `"CrossAttnUpBlockFlat"`.
        num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_eps,
        resnet_act_fn, resnet_groups, resnet_time_scale_shift:
            Forwarded unchanged to the selected block constructor.
        num_attention_heads, cross_attention_dim, dual_cross_attention, use_linear_projection,
        only_cross_attention:
            Forwarded only to `CrossAttnUpBlockFlat`. `cross_attention_dim` is required for that block.
        upcast_attention, resnet_skip_time_act, resnet_out_scale_factor, cross_attention_norm:
            Accepted so callers can pass a uniform keyword set; not forwarded by this factory.
        transformer_layers_per_block, attention_head_dim:
            Accepted so the signature mirrors `get_down_block`; `UNetFlatConditionModel.__init__` passes
            `transformer_layers_per_block` by keyword. Not forwarded by this factory.

    Returns:
        The constructed `UpBlockFlat` or `CrossAttnUpBlockFlat`.

    Raises:
        ValueError: If `cross_attention_dim` is `None` for `CrossAttnUpBlockFlat`, or if the block type is
            not one of the supported names.
    """
    # NOTE(review): `transformer_layers_per_block` / `attention_head_dim` are accepted but intentionally not
    # forwarded; confirm the Flat block constructors take them before wiring them through.
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlockFlat":
        return UpBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif up_block_type == "CrossAttnUpBlockFlat":
        # Validate before touching the block class so the caller gets a clear configuration error.
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat")
        return CrossAttnUpBlockFlat(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{up_block_type} is not supported.")
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`], [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. 
If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` otherwise. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat", ), mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", up_block_types: Tuple[str] = ( "UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 
dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads`" " because of a naming issue as described in" " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing" " `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. 
num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:" f" {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:" f" {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( "Must provide the same number of `only_cross_attention` as `down_block_types`." f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:" f" {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:" f" {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:" f" {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:" f" {layers_per_block}. `down_block_types`: {down_block_types}." 
) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = LinearMultiDim( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = 
get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlockFlatCrossAttn": self.mid_block = UNetMidBlockFlatCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, 
upcast_attention=upcast_attention, ) elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn": self.mid_block = UNetMidBlockFlatSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, 
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Collect every attention processor currently installed in the model.

    Returns:
        `dict` of attention processors: maps `"<module path>.processor"` to the `processor` attribute of each
        sub-module that exposes a `set_processor` method.
    """
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
        # A module is treated as an attention layer when it exposes `set_processor`.
        if hasattr(module, "set_processor"):
            processors[f"{name}.processor"] = module.processor

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors


def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single processor instance, which is installed on **all** attention layers, or a dict whose
            keys are the `"<module path>.processor"` names returned by `attn_processors`. The dict form is
            strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: if a dict is passed whose length differs from the number of attention layers.
        KeyError: if a dict is passed that lacks the key of one of the attention layers.

    The passed dict is only read, never modified, so the same mapping can be reused after the call.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                # Look the entry up instead of popping it: popping emptied the caller's dict as a side effect.
                module.set_processor(processor[f"{name}.processor"])

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
    """
    Replace any custom attention processors with a freshly created `AttnProcessor`.
    """
    default_processor = AttnProcessor()
    self.set_attn_processor(default_processor)


def set_attention_slice(self, slice_size):
    r"""
    Enable sliced attention computation.

    With slicing enabled, each attention module computes attention in several steps over slices of its input,
    which saves memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`):
            `"auto"` uses half of every layer's `sliceable_head_dim`; `"max"` uses a slice size of 1 for every
            layer; a single number is applied to every layer; a list supplies one value per sliceable layer, in
            module traversal order.

    Raises:
        ValueError: if a list of the wrong length is given, or a slice size exceeds a layer's
            `sliceable_head_dim`.
    """
    head_dims = []

    def collect_head_dims(module: torch.nn.Module):
        # Depth-first walk; every module exposing `set_attention_slice` contributes one entry.
        if hasattr(module, "set_attention_slice"):
            head_dims.append(module.sliceable_head_dim)

        for sub_module in module.children():
            collect_head_dims(sub_module)

    for top_level in self.children():
        collect_head_dims(top_level)

    layer_count = len(head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between speed and memory
        slice_size = [dim // 2 for dim in head_dims]
    elif slice_size == "max":
        # smallest possible slice for every layer
        slice_size = layer_count * [1]

    # A scalar is broadcast to all layers; a list is taken as-is.
    if not isinstance(slice_size, list):
        slice_size = layer_count * [slice_size]

    if len(slice_size) != len(head_dims):
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(head_dims)}."
        )

    for size, dim in zip(slice_size, head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Reversed copy, so that `pop()` hands the values out in the same order the layers were collected.
    pending_sizes = slice_size[::-1]

    def assign_slices(module: torch.nn.Module, pending: List[int]):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(pending.pop())

        for sub_module in module.children():
            assign_slices(sub_module, pending)

    for top_level in self.children():
        assign_slices(top_level, pending_sizes)
def _set_gradient_checkpointing(self, module, value=False):
    """Set `gradient_checkpointing` on `module` when it is one of the flat down/up block types."""
    if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
        module.gradient_checkpointing = value


def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    The [`UNetFlatConditionModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.FloatTensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        class_labels (`torch.Tensor`, *optional*):
            Required whenever the model was built with a class embedding.
        timestep_cond (`torch.Tensor`, *optional*):
            Passed to the time embedding layer together with the projected timesteps.
        attention_mask (`torch.Tensor`, *optional*):
            Mask of shape `(batch, key_tokens)` with 1 = keep and 0 = discard. It is converted into an additive
            bias before being handed to the blocks.
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
        added_cond_kwargs: (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
            that are passed along to the UNet blocks. `None` is treated as an empty dict.
        down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            Tensors added one-to-one to the residuals collected from the down blocks.
        mid_block_additional_residual (`torch.Tensor`, *optional*):
            Tensor added to the sample after the mid block.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the sample tensor.

    Raises:
        ValueError: if `class_labels` is missing while a class embedding is configured, or if an entry required
            by the configured `addition_embed_type` / `encoder_hid_dim_type` is missing from
            `added_cond_kwargs`.
    """
    # `added_cond_kwargs` defaults to None but is queried with `in` below. Normalising it here lets the
    # descriptive ValueErrors fire instead of "argument of type 'NoneType' is not iterable".
    if added_cond_kwargs is None:
        added_cond_kwargs = {}

    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch,                    1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None:
        # assume that mask is expressed as:
        #   (1 = keep,      0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #       (keep = +0,     discard = -10000.0)
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

            # `Timesteps` does not contain any weights and will always return f32 tensors
            # there might be better ways to encapsulate this.
            class_labels = class_labels.to(dtype=sample.dtype)

        class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    if self.config.addition_embed_type == "text":
        aug_emb = self.add_embedding(encoder_hidden_states)
    elif self.config.addition_embed_type == "text_image":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
                " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )

        image_embs = added_cond_kwargs.get("image_embeds")
        text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
        aug_emb = self.add_embedding(text_embs, image_embs)
    elif self.config.addition_embed_type == "text_time":
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
                " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
                " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)
    elif self.config.addition_embed_type == "image":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
                " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        image_embs = added_cond_kwargs.get("image_embeds")
        aug_emb = self.add_embedding(image_embs)
    elif self.config.addition_embed_type == "image_hint":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
                " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        image_embs = added_cond_kwargs.get("image_embeds")
        hint = added_cond_kwargs.get("hint")
        aug_emb, hint = self.add_embedding(image_embs, hint)
        # The hint returned by the embedding layer is appended to the sample along dim 1.
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
        # Kandinsky 2.1 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
                " requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )

        image_embeds = added_cond_kwargs.get("image_embeds")
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
    elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2 - style
        if "image_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
                " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        image_embeds = added_cond_kwargs.get("image_embeds")
        encoder_hidden_states = self.encoder_hid_proj(image_embeds)

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    if down_block_additional_residuals is not None:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. mid
    if self.mid_block is not None:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )

    if mid_block_additional_residual is not None:
        sample = sample + mid_block_additional_residual

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Each up block consumes as many skip tensors as it has resnets, taken from the end.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    if self.conv_norm_out is not None:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
class LinearMultiDim(nn.Linear):
    """
    Linear layer over several trailing dimensions at once.

    The trailing `len(in_features)` dimensions of the input are flattened, passed through a single `nn.Linear`,
    and the result is reshaped to the `out_features` layout.

    Args:
        in_features: an int `c` (expanded to `[c, second_dim, 1]`) or an explicit sequence of sizes.
        out_features: same conventions as `in_features`; `None` reuses the input layout.
        second_dim: middle size used when an int is expanded.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs):
        if isinstance(in_features, int):
            in_layout = [in_features, second_dim, 1]
        else:
            in_layout = list(in_features)

        if out_features is None:
            out_features = in_layout

        if isinstance(out_features, int):
            out_layout = [out_features, second_dim, 1]
        else:
            out_layout = list(out_features)

        # The layouts are plain lists, so they can be stored before `nn.Linear` is initialised.
        self.in_features_multidim = in_layout
        self.out_features_multidim = out_layout
        flat_in = np.array(in_layout).prod()
        flat_out = np.array(out_layout).prod()
        super().__init__(flat_in, flat_out)

    def forward(self, input_tensor, *args, **kwargs):
        """Flatten the trailing feature dims, apply the linear map, and restore the output layout."""
        trailing = len(self.in_features_multidim)
        leading_shape = input_tensor.shape[0:-trailing]
        flattened = input_tensor.reshape(*leading_shape, self.in_features)
        projected = super().forward(flattened)
        return projected.view(*leading_shape, *self.out_features_multidim)
class ResnetBlockFlat(nn.Module):
    """
    Residual block for flattened multi-dimensional features.

    The trailing `len(in_channels)` dimensions of the input are flattened into one channel axis and treated as a
    `1 x 1` feature map. The block applies GroupNorm -> SiLU -> 1x1 conv, adds the projected time embedding,
    applies GroupNorm -> SiLU -> dropout -> 1x1 conv, and adds the (optionally projected) input back.

    Args:
        in_channels: an int `c` (expanded to `[c, second_dim, 1]`) or an explicit sequence of sizes.
        out_channels: same conventions as `in_channels`; `None` keeps the input layout.
        dropout: dropout probability applied before the second convolution.
        temb_channels: size of the incoming time embedding; `None` builds the block without a time projection.
        groups: number of groups of the first GroupNorm.
        groups_out: number of groups of the second GroupNorm; defaults to `groups`.
        pre_norm: accepted for signature compatibility; the block always stores `pre_norm = True`.
        eps: epsilon of both GroupNorm layers.
        time_embedding_norm: stored on the instance only.
        use_in_shortcut: force (or suppress) the 1x1 shortcut convolution; by default it is created when the
            flattened input and output sizes differ.
        second_dim: middle size used when an int channel count is expanded.
        **kwargs: ignored. The flat down/up blocks pass `non_linearity` and `output_scale_factor` here; the
            activation is always SiLU.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        time_embedding_norm="default",
        use_in_shortcut=None,
        second_dim=4,
        **kwargs,
    ):
        super().__init__()
        # The passed `pre_norm` was always overwritten with True; keep only the effective assignment.
        self.pre_norm = True

        in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels)
        self.in_channels_prod = np.array(in_channels).prod()
        self.channels_multidim = in_channels

        if out_channels is not None:
            out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels)
            out_channels_prod = np.array(out_channels).prod()
            self.out_channels_multidim = out_channels
        else:
            out_channels_prod = self.in_channels_prod
            self.out_channels_multidim = self.channels_multidim
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True)
        self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0)

        self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = (
            self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut
        )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0
            )

    def forward(self, input_tensor, temb):
        """
        Run the residual block.

        Args:
            input_tensor: tensor whose trailing dimensions match the configured input layout.
            temb: time embedding, or `None`. It is ignored when the block was built with `temb_channels=None`.

        Returns:
            Tensor with the leading dimensions of `input_tensor` followed by the configured output layout.
        """
        shape = input_tensor.shape
        n_dim = len(self.channels_multidim)
        # The first reshape checks the trailing dims against the configured layout; the view then folds all
        # leading dims into one batch axis for the 2D layers.
        input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1)
        input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1)

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        # Guard on the projection as well: a block built with `temb_channels=None` has no `time_emb_proj`,
        # and calling it used to fail with "'NoneType' object is not callable".
        if temb is not None and self.time_emb_proj is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        output_tensor = output_tensor.view(*shape[0:-n_dim], -1)
        output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim)

        return output_tensor
# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class DownBlockFlat(nn.Module):
    """
    Down block made of `num_layers` `ResnetBlockFlat` layers, optionally followed by one `LinearMultiDim`
    "downsampler".

    `forward` returns the final hidden states together with a tuple holding the output of every resnet and, if
    present, of the downsampling stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Only the first resnet sees `in_channels`; all later ones map out_channels -> out_channels.
        layers = [
            ResnetBlockFlat(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
            for layer_idx in range(num_layers)
        ]
        self.resnets = nn.ModuleList(layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = LinearMultiDim(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Apply all resnets (and the downsampler, if any), collecting every intermediate output."""
        collected = []

        for resnet in self.resnets:
            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
            else:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # `use_reentrant` is only passed on torch >= 1.11.0.
                ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs
                )

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected.append(hidden_states)

        return hidden_states, tuple(collected)
# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
    """
    Down block that pairs every `ResnetBlockFlat` with a transformer (cross-attention) layer, optionally
    followed by one `LinearMultiDim` "downsampler".

    With `dual_cross_attention=True` each attention layer is a `DualTransformer2DModel`, otherwise a
    `Transformer2DModel`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        head_dim = out_channels // num_attention_heads
        resnet_layers = []
        attention_layers = []

        # Resnet and attention layers are created alternately, one pair per iteration.
        for layer_idx in range(num_layers):
            resnet_layers.append(
                ResnetBlockFlat(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

            if dual_cross_attention:
                attention = DualTransformer2DModel(
                    num_attention_heads,
                    head_dim,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention = Transformer2DModel(
                    num_attention_heads,
                    head_dim,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            attention_layers.append(attention)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = LinearMultiDim(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run each (resnet, attention) pair and the optional downsampler.

        Returns:
            `(hidden_states, output_states)` where `output_states` is a tuple with the output of every
            attention layer and, if present, of the downsampling stage.
        """
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                attn_out = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )
                hidden_states = attn_out[0]
            else:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is None:
                            return module(*inputs)
                        return module(*inputs, return_dict=return_dict)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # Positional order must match the attention layer's signature.
                attn_out = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    None,  # timestep
                    None,  # class_labels
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                    **ckpt_kwargs,
                )
                hidden_states = attn_out[0]

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            collected.append(hidden_states)

        return hidden_states, tuple(collected)
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
    """
    Up block made of `num_layers` `ResnetBlockFlat` layers, optionally followed by one `LinearMultiDim`
    "upsampler".

    Before every resnet the current hidden states are concatenated along dim 1 with the last remaining entry of
    `res_hidden_states_tuple`.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        last_idx = num_layers - 1
        layers = []
        for layer_idx in range(num_layers):
            # Channels arriving from the previous stage, and from the skip connection, respectively.
            incoming = prev_output_channel if layer_idx == 0 else out_channels
            skip = in_channels if layer_idx == last_idx else out_channels

            layers.append(
                ResnetBlockFlat(
                    in_channels=incoming + skip,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(layers)

        self.upsamplers = None
        if add_upsample:
            upsampler = LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)
            self.upsamplers = nn.ModuleList([upsampler])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Consume skip tensors from the end of `res_hidden_states_tuple`, one per resnet, then upsample."""
        remaining_skips = list(res_hidden_states_tuple)

        for resnet in self.resnets:
            # pop res hidden states
            skip_states = remaining_skips.pop()
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                continue

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # `use_reentrant` is only passed on torch >= 1.11.0.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs
            )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
# Adapted from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
    """Up block that alternates `ResnetBlockFlat` layers with cross-attention transformer layers.

    Each of the `num_layers` stages is one resnet followed by either a `Transformer2DModel` or, when
    `dual_cross_attention` is set, a `DualTransformer2DModel`. An optional `LinearMultiDim` upsampler
    closes the block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        resnet_layers = []
        attention_layers = []
        final_idx = num_layers - 1
        for idx in range(num_layers):
            main_width = prev_output_channel if idx == 0 else out_channels
            skip_width = in_channels if idx == final_idx else out_channels

            # Resnet first, attention second: the construction order is kept stage by stage.
            resnet_layers.append(
                ResnetBlockFlat(
                    in_channels=main_width + skip_width,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

            if dual_cross_attention:
                attention = DualTransformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention = Transformer2DModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                )
            attention_layers.append(attention)

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Apply every (resnet, attention) stage to the skip-concatenated hidden states, then upsample.

        One skip tensor is popped from the end of `res_hidden_states_tuple` per stage and concatenated
        along dim 1. The attention layer is always asked for a tuple (`return_dict=False`) and its first
        element becomes the new hidden states.
        """

        def _positional_call(module):
            def run(*inputs):
                return module(*inputs)

            return run

        def _positional_call_as_tuple(module):
            def run(*inputs):
                return module(*inputs, return_dict=False)

            return run

        remaining_skips = res_hidden_states_tuple
        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            skip, remaining_skips = remaining_skips[-1], remaining_skips[:-1]
            hidden_states = torch.cat([hidden_states, skip], dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                continue

            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            hidden_states = torch.utils.checkpoint.checkpoint(
                _positional_call(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            # Checkpointed calls are positional, so the two unused slots are filled with None.
            attn_outputs = torch.utils.checkpoint.checkpoint(
                _positional_call_as_tuple(attn),
                hidden_states,
                encoder_hidden_states,
                None,  # timestep
                None,  # class_labels
                cross_attention_kwargs,
                attention_mask,
                encoder_attention_mask,
                **ckpt_kwargs,
            )
            hidden_states = attn_outputs[0]

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
# Adapted from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
    """Middle block: one leading `ResnetBlockFlat`, then `num_layers` (attention, resnet) pairs.

    All layers keep the channel count at `in_channels`. The attention layer is a `Transformer2DModel`,
    or a `DualTransformer2DModel` when `dual_cross_attention` is set.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def new_resnet():
            # Every resnet in this block shares one configuration.
            return ResnetBlockFlat(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        # there is always at least one resnet
        resnet_layers = [new_resnet()]
        attention_layers = []

        for _ in range(num_layers):
            if dual_cross_attention:
                attention = DualTransformer2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            else:
                attention = Transformer2DModel(
                    num_attention_heads,
                    in_channels // num_attention_heads,
                    in_channels=in_channels,
                    num_layers=transformer_layers_per_block,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                )
            attention_layers.append(attention)
            resnet_layers.append(new_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the leading resnet, then each attention layer followed by its resnet.

        The attention layers are called with `return_dict=False`; the first element of the returned
        tuple is used as the new hidden states.
        """
        first_resnet, later_resnets = self.resnets[0], self.resnets[1:]

        hidden_states = first_resnet(hidden_states, temb)
        for attn, resnet in zip(self.attentions, later_resnets):
            attn_outputs = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )
            hidden_states = resnet(attn_outputs[0], temb)

        return hidden_states
# Adapted from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
    """Middle block: one leading `ResnetBlockFlat`, then `num_layers` (`Attention`, resnet) pairs.

    The attention layers use added key/value projections (`added_kv_proj_dim=cross_attention_dim`) and
    an `AttnAddedKVProcessor2_0` when `F.scaled_dot_product_attention` exists, `AttnAddedKVProcessor`
    otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_head_dim=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        skip_time_act=False,
        only_cross_attention=False,
        cross_attention_norm=None,
    ):
        super().__init__()

        self.has_cross_attention = True

        self.attention_head_dim = attention_head_dim
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        self.num_heads = in_channels // self.attention_head_dim

        def new_resnet():
            # Every resnet in this block shares one configuration.
            return ResnetBlockFlat(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                skip_time_act=skip_time_act,
            )

        # there is always at least one resnet
        resnet_layers = [new_resnet()]
        attention_layers = []

        for _ in range(num_layers):
            # A fresh processor instance is created for every attention layer.
            if hasattr(F, "scaled_dot_product_attention"):
                processor = AttnAddedKVProcessor2_0()
            else:
                processor = AttnAddedKVProcessor()

            attention_layers.append(
                Attention(
                    query_dim=in_channels,
                    cross_attention_dim=in_channels,
                    heads=self.num_heads,
                    dim_head=self.attention_head_dim,
                    added_kv_proj_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
                    cross_attention_norm=cross_attention_norm,
                    processor=processor,
                )
            )
            resnet_layers.append(new_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """Run the leading resnet, then each attention layer followed by its resnet.

        Mask selection: an explicit `attention_mask` always wins; otherwise `encoder_attention_mask` is
        used, but only when `encoder_hidden_states` is given. `cross_attention_kwargs` (default: empty)
        is expanded into keyword arguments of every attention call.
        """
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        if attention_mask is not None:
            # when attention_mask is defined: we don't even check for encoder_attention_mask.
            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
            #       then we can simplify this whole if/else block to:
            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
            mask = attention_mask
        elif encoder_hidden_states is None:
            mask = None
        else:
            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
            mask = encoder_attention_mask

        first_resnet, later_resnets = self.resnets[0], self.resnets[1:]

        hidden_states = first_resnet(hidden_states, temb)
        for attn, resnet in zip(self.attentions, later_resnets):
            attended = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=mask,
                **cross_attention_kwargs,
            )
            hidden_states = resnet(attended, temb)

        return hidden_states
class VersatileDiffusionPipeline(DiffusionPipeline):
    r"""
    Umbrella pipeline for Versatile Diffusion. It registers one shared set of components and exposes three tasks
    on top of them: [`~VersatileDiffusionPipeline.image_variation`], [`~VersatileDiffusionPipeline.text_to_image`]
    and [`~VersatileDiffusionPipeline.dual_guided`]. Each task method builds the matching task-specific pipeline
    from the registered components and delegates the call to it.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        tokenizer ([`CLIPTokenizer`]):
            Tokenizer for text prompts.
        image_feature_extractor ([`CLIPImageProcessor`]):
            Pre-processor for image prompts.
        text_encoder ([`CLIPTextModel`]):
            Text encoder.
        image_encoder ([`CLIPVisionModel`]):
            Image encoder.
        image_unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the encoded image latents.
        text_unet ([`UNet2DConditionModel`]):
            Second U-Net; the text-to-image and dual-guided task pipelines temporarily combine its attention
            blocks with those of `image_unet`.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with the U-Nets to denoise the encoded image latents.
    """

    tokenizer: CLIPTokenizer
    image_feature_extractor: CLIPImageProcessor
    text_encoder: CLIPTextModel
    image_encoder: CLIPVisionModel
    image_unet: UNet2DConditionModel
    text_unet: UNet2DConditionModel
    vae: AutoencoderKL
    scheduler: KarrasDiffusionSchedulers

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        image_feature_extractor: CLIPImageProcessor,
        text_encoder: CLIPTextModel,
        image_encoder: CLIPVisionModel,
        image_unet: UNet2DConditionModel,
        text_unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
    ):
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            image_feature_extractor=image_feature_extractor,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            image_unet=image_unet,
            text_unet=text_unet,
            vae=vae,
            scheduler=scheduler,
        )
        # Each VAE block beyond the first contributes a factor of two to the scale factor.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def _components_for(self, pipeline_cls):
        """Return the registered components whose names are parameters of `pipeline_cls.__init__`."""
        accepted = inspect.signature(pipeline_cls.__init__).parameters.keys()
        return {name: component for name, component in self.components.items() if name in accepted}

    @torch.no_grad()
    def image_variation(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate variations of an image by delegating to [`VersatileDiffusionImageVariationPipeline`].

        Args:
            image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
                The image prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the input `image`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a pipeline output object instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionPipeline
        >>> import torch
        >>> import requests
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> # let's download an initial image
        >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"

        >>> response = requests.get(url)
        >>> image = Image.open(BytesIO(response.content)).convert("RGB")

        >>> pipe = VersatileDiffusionPipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> image = pipe.image_variation(image, generator=generator).images[0]
        >>> image.save("./car_variation.png")
        ```

        Returns:
            The value returned by [`VersatileDiffusionImageVariationPipeline`] for the given `return_dict`,
            forwarded unchanged: an output object exposing `.images` (see the example) when `return_dict` is True,
            otherwise a `tuple` whose first element is the list of generated images.
        """
        components = self._components_for(VersatileDiffusionImageVariationPipeline)
        return VersatileDiffusionImageVariationPipeline(**components)(
            image=image,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def text_to_image(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate images from text by delegating to [`VersatileDiffusionTextToImagePipeline`].

        The U-Net attention blocks are swapped back to their original state when the call finishes, whether it
        succeeds or raises.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a pipeline output object instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionPipeline
        >>> import torch

        >>> pipe = VersatileDiffusionPipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
        >>> image.save("./astronaut.png")
        ```

        Returns:
            The value returned by [`VersatileDiffusionTextToImagePipeline`] for the given `return_dict`, forwarded
            unchanged: an output object exposing `.images` (see the example) when `return_dict` is True, otherwise a
            `tuple` whose first element is the list of generated images.
        """
        components = self._components_for(VersatileDiffusionTextToImagePipeline)
        temp_pipeline = VersatileDiffusionTextToImagePipeline(**components)
        try:
            return temp_pipeline(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
            )
        finally:
            # swap the attention blocks back to the original state, also when generation raised, so the
            # shared U-Nets are not left modified for the next task call
            temp_pipeline._swap_unet_attention_blocks()

    @torch.no_grad()
    def dual_guided(
        self,
        prompt: Union[str, List[str]],
        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
        text_to_image_strength: float = 0.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Generate images guided by both a text prompt and an image by delegating to
        [`VersatileDiffusionDualGuidedPipeline`].

        The dual-attention conversion of `image_unet` is reverted when the call finishes, whether it succeeds or
        raises.

        Args:
            prompt (`str` or `List[str]`):
                The text prompt or prompts to guide the image generation.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
                The image prompt or prompts to guide the image generation.
            text_to_image_strength (`float`, *optional*, defaults to 0.5):
                Forwarded unchanged to the dual-guided pipeline; presumably balances the text conditioning against
                the image conditioning (the example below uses 0.75).
            height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the prompts,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a pipeline output object instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionPipeline
        >>> import torch
        >>> import requests
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> # let's download an initial image
        >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"

        >>> response = requests.get(url)
        >>> image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> text = "a red car in the sun"

        >>> pipe = VersatileDiffusionPipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> text_to_image_strength = 0.75

        >>> image = pipe.dual_guided(
        ...     prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
        ... ).images[0]
        >>> image.save("./car_variation.png")
        ```

        Returns:
            [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images.
        """
        components = self._components_for(VersatileDiffusionDualGuidedPipeline)
        temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components)
        try:
            return temp_pipeline(
                prompt=prompt,
                image=image,
                text_to_image_strength=text_to_image_strength,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
            )
        finally:
            # undo the dual-attention conversion even when generation raised, so `image_unet` is not left
            # in its converted state for the next task call
            temp_pipeline._revert_dual_attention()
import inspect import warnings from typing import Callable, List, Optional, Tuple, Union import numpy as np import PIL import torch import torch.utils.checkpoint from transformers import ( CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, ) from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" tokenizer: CLIPTokenizer image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModelWithProjection image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] def __init__( self, tokenizer: CLIPTokenizer, image_feature_extractor: CLIPImageProcessor, text_encoder: CLIPTextModelWithProjection, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( tokenizer=tokenizer, image_feature_extractor=image_feature_extractor, text_encoder=text_encoder, image_encoder=image_encoder, image_unet=image_unet, text_unet=text_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None and ( "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention ): # if loading from a universal checkpoint rather than a saved dual-guided pipeline self._convert_to_dual_attention() def remove_unused_weights(self): self.register_modules(text_unet=None) def _convert_to_dual_attention(self): """ Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks from both `image_unet` and `text_unet` """ for name, module in self.image_unet.named_modules(): if isinstance(module, Transformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) image_transformer = self.image_unet.get_submodule(parent_name)[index] text_transformer = self.text_unet.get_submodule(parent_name)[index] config = image_transformer.config dual_transformer = DualTransformer2DModel( num_attention_heads=config.num_attention_heads, 
attention_head_dim=config.attention_head_dim, in_channels=config.in_channels, num_layers=config.num_layers, dropout=config.dropout, norm_num_groups=config.norm_num_groups, cross_attention_dim=config.cross_attention_dim, attention_bias=config.attention_bias, sample_size=config.sample_size, num_vector_embeds=config.num_vector_embeds, activation_fn=config.activation_fn, num_embeds_ada_norm=config.num_embeds_ada_norm, ) dual_transformer.transformers[0] = image_transformer dual_transformer.transformers[1] = text_transformer self.image_unet.get_submodule(parent_name)[index] = dual_transformer self.image_unet.register_to_config(dual_cross_attention=True) def _revert_dual_attention(self): """ Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` """ for name, module in self.image_unet.named_modules(): if isinstance(module, DualTransformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] self.image_unet.register_to_config(dual_cross_attention=False) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not """ def normalize_embeddings(encoder_output): embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) embeds_pooled = encoder_output.text_embeds embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) return embeds batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens = [""] * batch_size max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, 
padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not """ def normalize_embeddings(encoder_output): embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) embeds = self.image_encoder.visual_projection(embeds) embeds_pooled = embeds[:, 0:1] embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) image_embeddings = self.image_encoder(pixel_values) image_embeddings = normalize_embeddings(image_embeddings) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) negative_prompt_embeds = self.image_encoder(pixel_values) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For 
classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs(self, prompt, image, height, width, callback_steps): if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): for name, module in self.image_unet.named_modules(): if isinstance(module, DualTransformer2DModel): module.mix_ratio = mix_ratio for i, type in enumerate(condition_types): if type == "text": module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings module.transformer_index_for_condition[i] = 1 # use the second (text) transformer else: module.condition_lengths[i] = 257 module.transformer_index_for_condition[i] = 0 # use the first (image) transformer @torch.no_grad() def __call__( self, prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], image: Union[str, List[str]], text_to_image_strength: float = 0.5, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. 
num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionDualGuidedPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> text = "a red car in the sun" >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe.remove_unused_weights() >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> text_to_image_strength = 0.75 >>> image = pipe( ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator ... ).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet height = height or self.image_unet.config.sample_size * self.vae_scale_factor width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. 
Raise error if not correct self.check_inputs(prompt, image, height, width, callback_steps) # 2. Define call parameters prompt = [prompt] if not isinstance(prompt, list) else prompt image = [image] if not isinstance(image, list) else image batch_size = len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompts prompt_embeds = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) dual_prompt_embeddings = torch.cat([prompt_embeds, image_embeddings], dim=1) prompt_types = ("text", "image") # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, dual_prompt_embeddings.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Combine the attention blocks of the image and text UNets self.set_transformer_params(text_to_image_strength, prompt_types) # 8. 
Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from typing import Callable, List, Optional, Union import numpy as np import PIL import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" image_feature_extractor: CLIPImageProcessor image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers def __init__( self, image_feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( image_feature_extractor=image_feature_extractor, image_encoder=image_encoder, image_unet=image_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). """ def normalize_embeddings(encoder_output): embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) embeds = self.image_encoder.visual_projection(embeds) embeds_pooled = embeds[:, 0:1] embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) return embeds if isinstance(prompt, torch.Tensor) and len(prompt.shape) == 4: prompt = list(prompt) batch_size = len(prompt) if isinstance(prompt, list) else 1 # get prompt text embeddings image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) image_embeddings = self.image_encoder(pixel_values) image_embeddings = normalize_embeddings(image_embeddings) # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional 
embeddings for classifier free guidance if do_classifier_free_guidance: uncond_images: List[str] if negative_prompt is None: uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, PIL.Image.Image): uncond_images = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_images = negative_prompt uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) negative_prompt_embeds = self.image_encoder(pixel_values) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and conditional embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_image_variation.StableDiffusionImageVariationPipeline.check_inputs def check_inputs(self, image, height, width, callback_steps): if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list) ): raise ValueError( "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" f" {type(image)}" ) if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` 
has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def __call__( self, image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): The image prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. 
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the input `image`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionImageVariationPipeline >>> import torch >>> import requests >>> from io import BytesIO >>> from PIL import Image >>> # let's download an initial image >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" >>> response = requests.get(url) >>> image = Image.open(BytesIO(response.content)).convert("RGB") >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> image = pipe(image, generator=generator).images[0] >>> image.save("./car_variation.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.image_unet.config.sample_size * self.vae_scale_factor width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. 
Raise error if not correct self.check_inputs(image, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt image_embeddings = self._encode_prompt( image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, image_embeddings.dtype, device, generator, latents, ) # 6. Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect import warnings from typing import Callable, List, Optional, Union import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): r""" This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
""" tokenizer: CLIPTokenizer image_feature_extractor: CLIPImageProcessor text_encoder: CLIPTextModelWithProjection image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] def __init__( self, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModelWithProjection, image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, image_unet=image_unet, text_unet=text_unet, vae=vae, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None: self._swap_unet_attention_blocks() def _swap_unet_attention_blocks(self): """ Swap the `Transformer2DModel` blocks between the image and text UNets """ for name, module in self.image_unet.named_modules(): if isinstance(module, Transformer2DModel): parent_name, index = name.rsplit(".", 1) index = int(index) self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( self.text_unet.get_submodule(parent_name)[index], self.image_unet.get_submodule(parent_name)[index], ) def remove_unused_weights(self): self.register_modules(text_unet=None) def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 
""" if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ if not hasattr(self.image_unet, "_hf_hook"): return self.device for module in self.image_unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 
""" def normalize_embeddings(encoder_output): embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) embeds_pooled = encoder_output.text_embeds embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) return embeds batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids if not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = normalize_embeddings(prompt_embeds) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = text_input_ids.shape[-1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = normalize_embeddings(negative_prompt_embeds) # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." 
) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, **kwargs, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. 
negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Examples: ```py >>> from diffusers import VersatileDiffusionTextToImagePipeline >>> import torch >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( ... 
"shi-labs/versatile-diffusion", torch_dtype=torch.float16 ... ) >>> pipe.remove_unused_weights() >>> pipe = pipe.to("cuda") >>> generator = torch.Generator(device="cuda").manual_seed(0) >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] >>> image.save("./astronaut.png") ``` Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.image_unet.config.sample_size * self.vae_scale_factor width = width or self.image_unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.image_unet.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 6. 
Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) return ImagePipelineOutput(images=image) ================================================ FILE: 6DoF/diffusers/pipelines/vq_diffusion/__init__.py ================================================ from ...utils import is_torch_available, is_transformers_available if is_transformers_available() and is_torch_available(): from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline ================================================ FILE: 6DoF/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py ================================================ # Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, List, Optional, Tuple, Union import torch from transformers import CLIPTextModel, CLIPTokenizer from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin, Transformer2DModel, VQModel from ...schedulers import VQDiffusionScheduler from ...utils import logging from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): """ Utility class for storing learned text embeddings for classifier free sampling """ @register_to_config def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): super().__init__() self.learnable = learnable if self.learnable: assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" assert length is not None, "learnable=True requires `length` to be set" embeddings = torch.zeros(length, hidden_size) else: embeddings = None self.embeddings = torch.nn.Parameter(embeddings) class VQDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using VQ Diffusion This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
    """
    Tokenize and encode `prompt` into normalized text embeddings for the transformer.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to encode.
        num_images_per_prompt (`int`):
            How many times each prompt embedding is repeated along the batch dimension.
        do_classifier_free_guidance (`bool`):
            When `True`, unconditional embeddings are prepended to the prompt embeddings so that a single
            forward pass covers both.

    Returns:
        The prompt embeddings. With classifier free guidance the result is the concatenation
        `[unconditional, conditional]` along dim 0.
    """
    batch_size = len(prompt) if isinstance(prompt, list) else 1
    max_tokens = self.tokenizer.model_max_length

    def encode_and_normalize(input_ids):
        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
        # While CLIP does normalize the pooled output of the text transformer when combining
        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
        hidden_states = self.text_encoder(input_ids.to(self.device))[0]
        return hidden_states / hidden_states.norm(dim=-1, keepdim=True)

    # get prompt text embeddings
    tokenized = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_tokens,
        return_tensors="pt",
    )
    text_input_ids = tokenized.input_ids

    if text_input_ids.shape[-1] > max_tokens:
        overflow_text = self.tokenizer.batch_decode(text_input_ids[:, max_tokens:])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {max_tokens} tokens: {overflow_text}"
        )
        text_input_ids = text_input_ids[:, :max_tokens]

    # one embedding per generated image: each prompt is repeated `num_images_per_prompt` times in a row
    prompt_embeds = encode_and_normalize(text_input_ids).repeat_interleave(num_images_per_prompt, dim=0)

    if not do_classifier_free_guidance:
        return prompt_embeds

    learned = self.learned_classifier_free_sampling_embeddings
    if learned.learnable:
        negative_prompt_embeds = learned.embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
    else:
        # fall back to encoding the empty prompt, padded to the same length as the real prompt
        uncond_input = self.tokenizer(
            [""] * batch_size,
            padding="max_length",
            max_length=text_input_ids.shape[-1],
            truncation=True,
            return_tensors="pt",
        )
        negative_prompt_embeds = encode_and_normalize(uncond_input.input_ids)

    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        batch_size * num_images_per_prompt, seq_len, -1
    )

    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    return torch.cat([negative_prompt_embeds, prompt_embeds])
num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor` of shape (batch), *optional*): Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated of completely masked latent pixels. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") batch_size = batch_size * num_images_per_prompt do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) # get the initial completely masked latents unless the user supplied it latents_shape = (batch_size, self.transformer.num_latent_pixels) if latents is None: mask_class = self.transformer.num_vector_embeds - 1 latents = torch.full(latents_shape, mask_class).to(self.device) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any(): raise ValueError( "Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0," f" {self.transformer.num_vector_embeds - 1} (inclusive)." 
def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
    """
    Zero out (in probability space) the least likely classes along dim 1 of `log_p_x_0`.

    Classes are visited from most to least probable; a class is kept while the cumulative probability of the
    classes *before* it is still below `truncation_rate`. The most probable class is therefore always kept.
    Dropped entries are set to `-inf` (i.e. `log(0)`); the input tensor is not modified.
    """
    ranked_log_probs, ranking = torch.sort(log_p_x_0, dim=1, descending=True)
    below_rate = torch.exp(ranked_log_probs).cumsum(dim=1) < truncation_rate

    # Shift the mask one position down the ranking so the decision for each class depends on the
    # cumulative mass of the classes ranked above it. Rank 0 has nothing above it and is always kept.
    keep_ranked = torch.ones_like(below_rate)
    keep_ranked[:, 1:, :] = below_rate[:, :-1, :]

    # `ranking.argsort(1)` is the inverse permutation: it maps the mask back to the original class order.
    keep_original_order = keep_ranked.gather(1, ranking.argsort(1))

    return log_p_x_0.masked_fill(~keep_original_order, -torch.inf)  # -inf = log(0)
from ..utils import ( OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available, is_torchsde_available, ) try: if not is_torch_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler try: if not is_flax_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_flax_objects import * # noqa F403 else: from 
.scheduling_ddim_flax import FlaxDDIMScheduler from .scheduling_ddpm_flax import FlaxDDPMScheduler from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler from .scheduling_utils_flax import ( FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, ) try: if not (is_torch_available() and is_scipy_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .scheduling_lms_discrete import LMSDiscreteScheduler try: if not (is_torch_available() and is_torchsde_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_consistency_models.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@dataclass
class CMStochasticIterativeSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
    """

    # Sample for the previous timestep, i.e. the value to feed into the next denoising iteration.
    prev_sample: torch.FloatTensor
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep: the timestep value to look up.
        schedule_timesteps (*optional*): the schedule to search; defaults to `self.timesteps`.

    Returns:
        The index of the matching entry, obtained via `.item()` on the match positions (so exactly one
        entry of the schedule has to equal `timestep`).
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    match_positions = (schedule == timestep).nonzero()
    return match_positions.item()
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Exactly one of `num_inference_steps` and `timesteps` has to be given. The resulting integer timestep
    indices are mapped to Karras sigmas with `self._convert_to_karras`, and the sigmas are mapped to the
    scaled timesteps with `self.sigma_to_t`. `self.sigmas` ends with one extra entry equal to
    `self.config.sigma_min`, so it is one element longer than `self.timesteps`.

    Args:
        num_inference_steps (`int`, *optional*):
            the number of diffusion steps used when generating samples with a pre-trained model. Must be
            positive and not larger than `self.config.num_train_timesteps`.
        device (`str` or `torch.device`, *optional*):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            custom timesteps used to support arbitrary spacing between timesteps. Must be non-empty, strictly
            descending and start below `self.config.num_train_timesteps`. If `None`, then the default
            timestep spacing strategy of equal spacing between timesteps is used. If passed,
            `num_inference_steps` must be `None`.

    Raises:
        ValueError: if both or neither of `num_inference_steps` / `timesteps` are given, or if the given
            value violates the constraints listed above.
    """
    if num_inference_steps is None and timesteps is None:
        raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

    num_train_timesteps = self.config.num_train_timesteps

    # Follow DDPMScheduler custom timesteps logic
    if timesteps is not None:
        if len(timesteps) == 0:
            # Without this guard the `timesteps[0]` check below fails with an opaque IndexError.
            raise ValueError("`timesteps` must contain at least one timestep.")

        if any(later >= earlier for earlier, later in zip(timesteps, timesteps[1:])):
            raise ValueError("`timesteps` must be in descending order.")

        if timesteps[0] >= num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.num_train_timesteps`:"
                f" {num_train_timesteps}."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.custom_timesteps = True
    else:
        if num_inference_steps <= 0:
            # Without this guard the integer division below fails with an opaque ZeroDivisionError.
            raise ValueError(f"`num_inference_steps` must be a positive integer but is {num_inference_steps}.")

        if num_inference_steps > num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                f" `self.config.num_train_timesteps`: {num_train_timesteps} as the unet model trained with"
                f" this scheduler can only handle maximal {num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        step_ratio = num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.custom_timesteps = False

    # Map timesteps to Karras sigmas directly for multistep sampling
    # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
    ramp = timesteps[::-1].copy() / (num_train_timesteps - 1)
    sigmas = self._convert_to_karras(ramp)
    timesteps = self.sigma_to_t(sigmas)

    # Append the terminal sigma. Read it from the config like the rest of the class does: `sigma_min` is
    # never assigned as an instance attribute.
    sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
    self.sigmas = torch.from_numpy(sigmas).to(device=device)

    if str(device).startswith("mps"):
        # mps does not support float64
        self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
    else:
        self.timesteps = torch.from_numpy(timesteps).to(device=device)
Returns: `tuple`: A two-element tuple where c_skip (which weights the current sample) is the first element and c_out (which weights the consistency model output) is the second element. """ sigma_min = self.config.sigma_min sigma_data = self.config.sigma_data c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2) c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`float`): current timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator (`torch.Generator`, *optional*): Random number generator. return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class Returns: [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if ( isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor) ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" f" `{self.__class__}.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." 
), ) if not self.is_scale_input_called: logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) sigma_min = self.config.sigma_min sigma_max = self.config.sigma_max step_index = self.index_for_timestep(timestep) # sigma_next corresponds to next_t in original implementation sigma = self.sigmas[step_index] if step_index + 1 < self.config.num_train_timesteps: sigma_next = self.sigmas[step_index + 1] else: # Set sigma_next to sigma_min sigma_next = self.sigmas[-1] # Get scalings for boundary conditions c_skip, c_out = self.get_scalings_for_boundary_condition(sigma) # 1. Denoise model output using boundary conditions denoised = c_out * model_output + c_skip * sample if self.config.clip_denoised: denoised = denoised.clamp(-1, 1) # 2. Sample z ~ N(0, s_noise^2 * I) # Noise is not used for onestep sampling. if len(self.timesteps) > 1: noise = randn_tensor( model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator ) else: noise = torch.zeros_like(model_output) z = noise * self.config.s_noise sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max) # 3. 
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Return `original_samples + noise * sigma`, where `sigma` is looked up per entry of `timesteps`.

    Each entry of `timesteps` is matched against `self.timesteps` and the `self.sigmas` value at the same
    position is used. The looked-up sigmas are given trailing singleton dimensions until they have as many
    dimensions as `original_samples`.
    """
    target_device = original_samples.device

    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigma_table = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    cast_kwargs = {}
    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        cast_kwargs["dtype"] = torch.float32
    schedule = self.timesteps.to(target_device, **cast_kwargs)
    timesteps = timesteps.to(target_device, **cast_kwargs)

    positions = []
    for t in timesteps:
        positions.append((schedule == t).nonzero().item())

    sigma = sigma_table[positions].flatten()
    missing_dims = original_samples.ndim - sigma.ndim
    if missing_dims > 0:
        sigma = sigma.reshape(tuple(sigma.shape) + (1,) * missing_dims)

    return original_samples + noise * sigma
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of shape `(num_diffusion_timesteps,)` holding the betas used by
        the scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at `max_beta`
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) # Convert alphas_bar_sqrt to betas alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod alphas = torch.cat([alphas_bar[0:1], alphas]) betas = 1 - alphas return betas class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): option to clip predicted sample for numerical stability. clip_sample_range (`float`, default `1.0`): the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. set_alpha_to_one (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. 
When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. rescale_betas_zero_snr (`bool`, default `False`): whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). This can enable the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") # Rescale for zero SNR if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) return variance # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." 
def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
    """
    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            the device the resulting timestep tensor is moved to. `None` is passed through to `Tensor.to`.

    Raises:
        ValueError: if `num_inference_steps` exceeds `self.config.num_train_timesteps`, or if
            `self.config.timestep_spacing` is not one of `"linspace"`, `"leading"` or `"trailing"`.
    """
    if num_inference_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        # Evenly spaced over [0, num_train_timesteps - 1], both ends included, then reversed to descend.
        # `.copy()` materialises the reversed view so `torch.from_numpy` gets a positively-strided array.
        timesteps = (
            np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
    elif self.config.timestep_spacing == "leading":
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Integer multiples of the stride, starting at 0, reversed to descend.
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        # `steps_offset` is applied in this branch only.
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
        # Count down from num_train_timesteps with a fractional stride; already descending, so no reversal.
        timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
        timesteps -= 1
    else:
        # The message lists every spacing handled above (the original omitted 'linspace').
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
            " 'leading' or 'trailing'."
        )

    self.timesteps = torch.from_numpy(timesteps).to(device)
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` will have not effect. generator: random number generator. variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion. (https://arxiv.org/abs/2210.05559) return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class Returns: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding # Notation ( -> # - pred_noise_t -> e_theta(x_t, t) # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) variance = self._get_variance(timestep, prev_timestep) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: # the pred_epsilon is always re-derived from the clipped x_0 in Glide pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon # 7. 
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Mix `noise` into `original_samples` at the noise level of each entry of `timesteps`.

    Returns `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise`, where `alpha_bar_t` is
    looked up in `self.alphas_cumprod` after casting it to the device and dtype of `original_samples`.
    """

    def _trailing_dims(coefficients):
        # Flatten, then append singleton axes until the rank matches `original_samples`.
        coefficients = coefficients.flatten()
        while len(coefficients.shape) < len(original_samples.shape):
            coefficients = coefficients.unsqueeze(-1)
        return coefficients

    # Make sure the schedule and the timesteps live where the samples live.
    schedule = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    steps = timesteps.to(original_samples.device)

    signal_scale = _trailing_dims(schedule[steps] ** 0.5)
    noise_scale = _trailing_dims((1 - schedule[steps]) ** 0.5)

    return signal_scale * original_samples + noise_scale * noise
torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) timesteps = timesteps.to(sample.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_ddim_flax.py ================================================ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion # and https://github.com/hojonathanho/diffusion from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, get_velocity_common, ) @flax.struct.dataclass class DDIMSchedulerState: common: CommonSchedulerState final_alpha_cumprod: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod def create( cls, common: CommonSchedulerState, final_alpha_cumprod: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput): state: DDIMSchedulerState class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. 
beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. `v-prediction` is not supported for this scheduler. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. 
""" _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype @property def has_state(self): return True @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, set_alpha_to_one: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: if common is None: common = CommonSchedulerState.create(self) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. final_alpha_cumprod = ( jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0] ) # standard deviation of the initial noise distribution init_noise_sigma = jnp.array(1.0, dtype=self.dtype) timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1] return DDIMSchedulerState.create( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) def scale_model_input( self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. sample (`jnp.ndarray`): input sample timestep (`int`, optional): current timestep Returns: `jnp.ndarray`: scaled input sample """ return sample def set_timesteps( self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = () ) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 
def step(
    self,
    state: DDIMSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    eta: float = 0.0,
    return_dict: bool = True,
) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        eta (`float`):
            scales sigma_t of formula (16). NOTE(review): this method only uses sigma_t to shrink the
            "direction pointing to x_t" term; no random noise term is added to `prev_sample` here, so a
            non-zero `eta` does not make the step stochastic.
        return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

    Returns:
        [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple` `(prev_sample, state)`. When returning a tuple, the first element is the sample tensor and the
        second is the (unchanged) `state` that was passed in.

    Raises:
        ValueError: if `state.num_inference_steps` is `None`, or if `self.config.prediction_type` is not one of
            `epsilon`, `sample` or `v_prediction`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    # The stride is the same integer ratio that `set_timesteps` uses to space the timesteps.
    prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

    alphas_cumprod = state.common.alphas_cumprod
    final_alpha_cumprod = state.final_alpha_cumprod

    # 2. compute alphas, betas
    alpha_prod_t = alphas_cumprod[timestep]
    # Select the boundary value when prev_timestep is negative; both candidates are computed before selection.
    alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 4. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = self._get_variance(state, timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)

    # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if not return_dict:
        return (prev_sample, state)

    return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # The message now spells the parameter name as it appears in the signature.
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta so no step removes all signal.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
set_alpha_to_zero (`bool`, default `True`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `0`, otherwise it uses the value of alpha at step `num_train_timesteps - 1`. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_zero=False`, to make the last step use step `num_train_timesteps - 1` for the previous alpha product. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) """ order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_zero: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", clip_sample_range: float = 1.0, **kwargs, ): if kwargs.get("set_alpha_to_one", None) is not None: deprecation_message = ( "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead." ) deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False) set_alpha_to_zero = kwargs["set_alpha_to_one"] if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in inverted ddim, we are looking into the next alphas_cumprod # For the final step, there is no next alphas_cumprod, and the index is out of bounds # `set_alpha_to_zero` decides whether we set this parameter simply to zero # in this case, self.step() just output the predicted noise # or whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. 
""" if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) self.timesteps = torch.from_numpy(timesteps).to(device) self.timesteps += self.config.steps_offset def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, eta: float = 0.0, use_clipped_model_output: bool = False, variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: # 1. get previous step value (=t+1) prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas # change original implementation to exactly match noise levels for analogous forward process alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = ( self.alphas_cumprod[prev_timestep] if prev_timestep < self.config.num_train_timesteps else self.final_alpha_cumprod ) beta_prod_t = 1 - alpha_prod_t # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`" ) # 4. Clip or threshold "predicted x_0" if self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if not return_dict: return (prev_sample, pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_ddim_parallel.py ================================================ # Copyright 2023 ParaDiGMS authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize a continuous `alpha_bar(t)` function (the cumulative product of `1 - beta` over `t` in [0, 1]) into a
    per-step beta schedule.

    For every step `i`, the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with `num_diffusion_timesteps` entries, the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    else:

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    capped = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta) for step in range(total)
    ]
    return torch.tensor(capped, dtype=torch.float32)
class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf)
        thresholding (`bool`, default `False`):
            whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
            Note that the thresholding method is unsuitable for latent-space diffusion models (such as
            stable-diffusion).
        dynamic_thresholding_ratio (`float`, default `0.995`):
            the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
            (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
        sample_max_value (`float`, default `1.0`):
            the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, default `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        rescale_betas_zero_snr (`bool`, default `False`):
            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
            medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 1
    _is_ode_scheduler = True

    @register_to_config
    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.__init__
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def _get_variance(self, timestep, prev_timestep=None):
        """
        Compute the DDIM variance term for a single `timestep`.

        When `prev_timestep` is `None` it defaults to one inference stride before `timestep`. A negative
        `prev_timestep` uses `self.final_alpha_cumprod` as the previous alpha product.
        """
        if prev_timestep is None:
            prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def _batch_get_variance(self, t, prev_t):
        """
        Batched counterpart of `_get_variance`: `t` and `prev_t` are index tensors of matching shape.

        Entries with `prev_t < 0` use `self.final_alpha_cumprod` as the previous alpha product, so the result agrees
        with `_get_variance` for both settings of `set_alpha_to_one`.
        """
        alpha_prod_t = self.alphas_cumprod[t]
        # Clip so the gather is always in range; out-of-range entries are overwritten on the next line.
        # Indexing with a tensor yields a fresh tensor, so the in-place write does not touch `alphas_cumprod`.
        alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
        alpha_prod_t_prev[prev_t < 0] = self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, height, width = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * height * width)

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, height, width)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device the timesteps tensor is moved to.

        Raises:
            ValueError: if `num_inference_steps` exceeds `self.config.num_train_timesteps`, or if
                `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            # The message lists every spacing handled by the branches above.
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace',"
                " 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have not effect.
            generator: random number generator.
            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
                can directly provide the noise for the variance itself. This is useful for methods such as
                CycleDiffusion. (https://arxiv.org/abs/2210.05559)
            return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDIMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first (and only) element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called, if `prediction_type` is unsupported, or if both
                `generator` and `variance_noise` are given while `eta > 0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDIMParallelSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def batch_step_no_noise(
        self,
        model_output: torch.FloatTensor,
        timesteps: List[int],
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
    ) -> torch.FloatTensor:
        """
        Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
        Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
        is pre-sampled by the pipeline.

        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Note: this method moves `self.alphas_cumprod` and `self.final_alpha_cumprod` to `model_output.device`.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timesteps (`List[int]`):
                current discrete timesteps in the diffusion chain. Despite the annotation, the body calls `.view` on
                this argument, so it must be a tensor of timestep indices (one per leading-dimension entry of
                `model_output`).
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            eta (`float`): weight of noise for added noise in diffusion step. Must be `0.0`.
            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
                predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
                `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                coincide with the one provided as input and `use_clipped_model_output` will have not effect.

        Returns:
            `torch.FloatTensor`: sample tensor at previous timestep.

        Raises:
            ValueError: if `set_timesteps` has not been called or `prediction_type` is unsupported.
            AssertionError: if `eta` is not `0.0`.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        assert eta == 0.0

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        t = timesteps
        prev_t = t - self.config.num_train_timesteps // self.num_inference_steps

        # Reshape to (batch, 1, 1, ...) so the gathered alphas broadcast against `model_output`.
        t = t.view(-1, *([1] * (model_output.ndim - 1)))
        prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))

        # 2. compute alphas, betas
        self.alphas_cumprod = self.alphas_cumprod.to(model_output.device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(model_output.device)
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
        # Use the configured final alpha (1.0 or alphas_cumprod[0]) so the last step matches `step`.
        alpha_prod_t_prev[prev_t < 0] = self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._batch_get_variance(t, prev_t).to(model_output.device).view(*alpha_prod_t_prev.shape)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        return prev_sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Return `sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise` for the given `timesteps`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing singleton dims until the coefficient broadcasts against the samples.
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        """
        Return `sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample` for the given `timesteps`.
        """
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing singleton dims until the coefficient broadcasts against the sample.
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        """Return the number of training timesteps stored in the config."""
        return self.config.num_train_timesteps
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), capped at `max_beta`
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample for numerical stability. clip_sample_range (`float`, default `1.0`): the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. 
sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
    timesteps: Optional[List[int]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

    Exactly one of `num_inference_steps` and `timesteps` must be given.

    Args:
        num_inference_steps (`Optional[int]`):
            the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
            `timesteps` must be `None`.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps are moved to.
        timesteps (`List[int]`, optional):
            custom timesteps used to support arbitrary spacing between timesteps. They must be strictly descending
            and start below `self.config.num_train_timesteps`. If `None`, the spacing strategy selected by
            `self.config.timestep_spacing` is used. If passed, `num_inference_steps` must be `None`.

    Raises:
        ValueError: if both or neither of `num_inference_steps` and `timesteps` are given, if the custom timesteps
            are empty, not strictly descending or start too late, if `num_inference_steps` exceeds
            `self.config.num_train_timesteps`, or if `self.config.timestep_spacing` is not supported.
    """
    if num_inference_steps is not None and timesteps is not None:
        raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
    if num_inference_steps is None and timesteps is None:
        # Without this guard the range check below fails with an opaque `TypeError` (`None > int`).
        raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")

    if timesteps is not None:
        if len(timesteps) == 0:
            # An empty schedule would otherwise surface as an `IndexError` on `timesteps[0]`.
            raise ValueError("`custom_timesteps` must contain at least one timestep.")

        for i in range(1, len(timesteps)):
            if timesteps[i] >= timesteps[i - 1]:
                raise ValueError("`custom_timesteps` must be in descending order.")

        if timesteps[0] >= self.config.num_train_timesteps:
            raise ValueError(
                f"`timesteps` must start before `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps}."
            )

        timesteps = np.array(timesteps, dtype=np.int64)
        self.custom_timesteps = True
    else:
        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps
        self.custom_timesteps = False

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                .round()[::-1]
                .copy()
                .astype(np.int64)
            )
        elif self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

    self.timesteps = torch.from_numpy(timesteps).to(device)
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    generator=None,
    return_dict: bool = True,
) -> Union["DDPMSchedulerOutput", Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model. When it has twice the
            channels of `sample` and the variance type is `learned` or `learned_range`, the second half is split
            off as the predicted variance.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        generator: random number generator.
        return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

    Returns:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `self.config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    t = timestep

    prev_t = self.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t]
    alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction`  for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
    current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise (the final step, t == 0, is deterministic)
    variance = 0
    if t > 0:
        device = model_output.device
        variance_noise = randn_tensor(
            model_output.shape, generator=generator, device=device, dtype=model_output.dtype
        )
        variance_out = self._get_variance(t, predicted_variance=predicted_variance)
        if self.variance_type == "fixed_small_log":
            # `_get_variance` already returns a standard deviation for this type
            variance = variance_out * variance_noise
        elif self.variance_type in ("learned_range", "fixed_large_log"):
            # `_get_variance` returns a log-variance for these types, so std = exp(0.5 * log_var).
            # Taking `** 0.5` of the (negative) log value, as was done for `fixed_large_log`, yields NaN.
            variance = torch.exp(0.5 * variance_out) * variance_noise
        else:
            variance = (variance_out**0.5) * variance_noise

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample,)

    return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def previous_timestep(self, timestep):
    """
    Return the timestep that follows `timestep` in the denoising chain.

    With a regular schedule the result is `timestep` minus the stride
    `num_train_timesteps // num_inference_steps`; the stride is 1 while no inference step count has been set.
    With custom timesteps the result is the next entry of `self.timesteps`, or a tensor holding `-1` when
    `timestep` is the last entry.
    """
    if not self.custom_timesteps:
        # Fall back to the full training schedule while `num_inference_steps` is unset.
        step_count = self.num_inference_steps or self.config.num_train_timesteps
        return timestep - self.config.num_train_timesteps // step_count

    # First position at which the custom schedule equals `timestep`.
    position = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
    is_last = position == self.timesteps.shape[0] - 1
    return torch.tensor(-1) if is_last else self.timesteps[position + 1]
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, get_velocity_common, ) @flax.struct.dataclass class DDPMSchedulerState: common: CommonSchedulerState # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) @dataclass class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput): state: DDPMSchedulerState class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. `v-prediction` is not supported for this scheduler. dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. 
def scale_model_input(
    self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
) -> jnp.ndarray:
    """
    Return `sample` unchanged; this scheduler applies no input scaling, and `state` and `timestep` are not read.

    Args:
        state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
        sample (`jnp.ndarray`): input sample
        timestep (`int`, optional): current timestep

    Returns:
        `jnp.ndarray`: the input sample, as passed in
    """
    return sample
def step(
    self,
    state: "DDPMSchedulerState",
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    key: Optional[jnp.ndarray] = None,
    return_dict: bool = True,
) -> Union["FlaxDDPMSchedulerOutput", Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model. When it has twice the channels
            of `sample` (axis 1) and the variance type is `learned` or `learned_range`, the second half is split
            off as the predicted variance.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        key (`jnp.ndarray`, optional): a PRNG key. `jax.random.PRNGKey(0)` is used when omitted.
        return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

    Returns:
        [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple` `(prev_sample, state)`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `self.config.prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    t = timestep

    if key is None:
        key = jax.random.PRNGKey(0)

    if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
        # An integer second argument to `jnp.split` is the number of equal sections, so split in two halves.
        model_output, predicted_variance = jnp.split(model_output, 2, axis=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = state.common.alphas_cumprod[t]
    alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the FlaxDDPMScheduler."
        )

    # 3. Clip "predicted x_0"
    if self.config.clip_sample:
        pred_original_sample = jnp.clip(pred_original_sample, -1, 1)

    # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t
    current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

    # 5. Compute predicted previous sample µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

    # 6. Add noise
    def random_variance():
        # `jax.random.split` returns a batch of keys; `jax.random.normal` needs a single one.
        split_key = jax.random.split(key, num=1)[0]
        noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
        variance_out = self._get_variance(state, t, predicted_variance=predicted_variance)
        if self.config.variance_type in ("fixed_small_log", "fixed_large_log"):
            # `_get_variance` returns a log-variance for these types: std = exp(0.5 * log_var).
            # A square root of the (negative) log value would be NaN.
            return jnp.exp(0.5 * variance_out) * noise
        return (variance_out**0.5) * noise

    # The final step (t == 0) is deterministic.
    variance = jnp.where(t > 0, random_variance(), jnp.zeros(model_output.shape, dtype=self.dtype))

    pred_prev_sample = pred_prev_sample + variance

    if not return_dict:
        return (pred_prev_sample, state)

    return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput class DDPMParallelSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), capped at `max_beta`
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample for numerical stability. clip_sample_range (`float`, default `1.0`): the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. 
sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, default `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 _is_ode_scheduler = False @register_to_config # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.__init__ def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, timestep_spacing: str = "leading", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) elif beta_schedule == "sigmoid": # GeoDiff sigmoid schedule betas = torch.linspace(-6, 6, num_train_timesteps) self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.custom_timesteps = False self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.variance_type = variance_type # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.set_timesteps def set_timesteps( self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, timesteps: Optional[List[int]] = None, ): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`Optional[int]`): the number of diffusion steps used when generating samples with a pre-trained model. If passed, then `timesteps` must be `None`. device (`str` or `torch.device`, optional): the device to which the timesteps are moved to. 
custom_timesteps (`List[int]`, optional): custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps` must be `None`. """ if num_inference_steps is not None and timesteps is not None: raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") if timesteps is not None: for i in range(1, len(timesteps)): if timesteps[i] >= timesteps[i - 1]: raise ValueError("`custom_timesteps` must be in descending order.") if timesteps[0] >= self.config.num_train_timesteps: raise ValueError( f"`timesteps` must start before `self.config.train_timesteps`:" f" {self.config.num_train_timesteps}." ) timesteps = np.array(timesteps, dtype=np.int64) self.custom_timesteps = True else: if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.config.num_train_timesteps} timesteps." ) self.num_inference_steps = num_inference_steps self.custom_timesteps = False # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) .round()[::-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) self.timesteps = torch.from_numpy(timesteps).to(device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance def _get_variance(self, t, predicted_variance=None, variance_type=None): prev_t = self.previous_timestep(t) alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t # we always take the log of variance, so clamp it to ensure it's not 0 variance = torch.clamp(variance, min=1e-20) if variance_type is None: variance_type = self.config.variance_type # hacks - were probably added for training stability if variance_type == "fixed_small": variance = variance # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = torch.log(variance) variance = torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = current_beta_t elif variance_type == "fixed_large_log": # Glide max_log variance = torch.log(current_beta_t) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": min_log = torch.log(variance) max_log = torch.log(current_beta_t) frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log return variance # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. 
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, height, width = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * height * width) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, height, width) sample = sample.to(dtype) return sample def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None, return_dict: bool = True, ) -> Union[DDPMParallelSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. 
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class Returns: [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDPMParallelSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ t = timestep prev_t = self.previous_timestep(t) if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev current_alpha_t = alpha_prod_t / alpha_prod_t_prev current_beta_t = 1 - current_alpha_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for the DDPMScheduler." ) # 3. Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 4. 
Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample # 6. Add noise variance = 0 if t > 0: device = model_output.device variance_noise = randn_tensor( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise elif self.variance_type == "learned_range": variance = self._get_variance(t, predicted_variance=predicted_variance) variance = torch.exp(0.5 * variance) * variance_noise else: variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise pred_prev_sample = pred_prev_sample + variance if not return_dict: return (pred_prev_sample,) return DDPMParallelSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) def batch_step_no_noise( self, model_output: torch.FloatTensor, timesteps: List[int], sample: torch.FloatTensor, ) -> torch.FloatTensor: """ Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once. Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise is pre-sampled by the pipeline. Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. 
timesteps (`List[int]`): current discrete timesteps in the diffusion chain. This is now a list of integers. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: sample tensor at previous timestep. """ t = timesteps num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps prev_t = t - self.config.num_train_timesteps // num_inference_steps t = t.view(-1, *([1] * (model_output.ndim - 1))) prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: pass # 1. compute alphas, betas self.alphas_cumprod = self.alphas_cumprod.to(model_output.device) alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev current_alpha_t = alpha_prod_t / alpha_prod_t_prev current_beta_t = 1 - current_alpha_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.config.prediction_type == "sample": pred_original_sample = model_output elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for the DDPMParallelScheduler." ) # 3. 
Clip or threshold "predicted x_0" if self.config.thresholding: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range ) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample µ_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample return pred_prev_sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity def get_velocity( self, sample: torch.FloatTensor, noise: 
torch.FloatTensor, timesteps: torch.IntTensor ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as sample alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) timesteps = timesteps.to(sample.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(sample.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity def __len__(self): return self.config.num_train_timesteps # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): if self.custom_timesteps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: num_inference_steps = ( self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps ) prev_t = timestep - self.config.num_train_timesteps // num_inference_steps return prev_t ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_deis_multistep.py ================================================ # Copyright 2023 FLAIR Lab and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py import math from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ DEIS (https://arxiv.org/abs/2204.13902) is a fast high order solver for diffusion ODEs. We slightly modify the polynomial fitting formula in log-rho space instead of the original linear t space in DEIS paper. The modification enjoys closed-form coefficients for exponential multistep update instead of replying on the numerical solver. More variants of DEIS can be found in https://github.com/qsh-zh/deis. Currently, we support the log-rho multistep DEIS. We recommend to use `solver_order=2 / 3` while `solver_order=1` reduces to DDIM. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set `thresholding=True` to use the dynamic thresholding. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. 
beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DEIS; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v-prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` algorithm_type (`str`, default `deis`): the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in the future lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. 
If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # settings for DEIS if algorithm_type not in ["deis"]: if algorithm_type in ["dpmsolver", "dpmsolver++"]: self.register_to_config(algorithm_type="deis") else: raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") if solver_type not in ["logrho"]: if solver_type in ["midpoint", "heun", "bh1", "bh2"]: self.register_to_config(solver_type="logrho") else: raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}") # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
""" # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) self.sigmas = torch.from_numpy(sigmas) # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
# (generalized to samples with any number of trailing dimensions).
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """
    Apply dynamic thresholding to a predicted `x_0` sample.

    "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
    prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
    s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
    pixels from saturation at each step. We find that dynamic thresholding results in significantly better
    photorealism as well as better image-text alignment, especially when using very large guidance weights."

    https://arxiv.org/abs/2205.11487

    Args:
        sample (`torch.FloatTensor`):
            tensor whose first two dimensions are unpacked as `(batch_size, channels)`; any number of further
            dimensions is accepted. The percentile is computed per batch element over all remaining values.

    Returns:
        `torch.FloatTensor`: thresholded tensor with the same shape and dtype as `sample`.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten everything but the batch dimension so the quantile is taken per batch element
    sample = sample.reshape(batch_size, -1)

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
    s = torch.clamp(
        s, min=1, max=self.config.sample_max_value
    )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]

    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    # Restore the caller's original shape and dtype
    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
def deis_first_order_update(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    prev_timestep: int,
    sample: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    One step for the first-order DEIS (equivalent to DDIM).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the sample tensor at the previous timestep.

    Raises:
        NotImplementedError: if `self.config.algorithm_type` is not `"deis"`.
    """
    # Schedule values at the step we move to (prev_timestep) and the step we come from (timestep).
    lam_prev = self.lambda_t[prev_timestep]
    lam_cur = self.lambda_t[timestep]
    alpha_prev = self.alpha_t[prev_timestep]
    alpha_cur = self.alpha_t[timestep]
    sigma_prev = self.sigma_t[prev_timestep]

    # Step size measured in lambda
    h = lam_prev - lam_cur

    if self.config.algorithm_type != "deis":
        raise NotImplementedError("only support log-rho multistep deis now")

    sample_term = (alpha_prev / alpha_cur) * sample
    output_term = (sigma_prev * (torch.exp(h) - 1.0)) * model_output
    return sample_term - output_term
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 if self.config.algorithm_type == "deis": def ind_fn(t, b, c): # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) return x_t else: raise NotImplementedError("only support log-rho multistep deis now") def multistep_deis_third_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the third-order multistep DEIS. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] rho_t, rho_s0, rho_s1, rho_s2 = ( sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1, simga_s2 / alpha_s2, ) if self.config.algorithm_type == "deis": def ind_fn(t, b, c, d): # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] numerator = t * ( np.log(c) * (np.log(d) - np.log(t) + 1) - np.log(d) * np.log(t) + np.log(d) + np.log(t) ** 2 - 2 * np.log(t) + 2 ) denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) return numerator / denominator coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) return x_t else: raise NotImplementedError("only support log-rho multistep deis now") def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the multistep DEIS. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_deis_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_deis_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 if not return_dict: return 
(prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_dpmsolver_multistep.py ================================================ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta)
    up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). 
sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. 
For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `steps_offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion.
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps`, this resets `self.model_outputs` and `self.lower_order_nums`, stores the number of
    timesteps actually produced (after de-duplication) in `self.num_inference_steps`, and, when
    `use_karras_sigmas` is enabled, stores the Karras sigmas in `self.sigmas`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.

    Raises:
        ValueError: if `self.config.timestep_spacing` is not one of 'linspace', 'leading' or 'trailing'.
    """
    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
    last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = (
            np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
        )
    elif self.config.timestep_spacing == "leading":
        step_ratio = last_timestep // (num_inference_steps + 1)
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        # The stride must be derived from the clipped range that is actually traversed; using
        # `num_train_timesteps` here would yield fewer steps than requested whenever clipping is active.
        step_ratio = last_timestep / num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if self.config.use_karras_sigmas:
        # Replace the spacing-based timesteps by the ones matching the Karras noise levels
        log_sigmas = np.log(sigmas)
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = np.flip(timesteps).copy().astype(np.int64)

        self.sigmas = torch.from_numpy(sigmas)

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps.
    _, unique_indices = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(unique_indices)]

    self.timesteps = torch.from_numpy(timesteps).to(device)

    self.num_inference_steps = len(timesteps)

    # Reset the multistep history for the new sampling run
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
    """
    Map a noise level `sigma` to a (fractional) index into `log_sigmas` by linear interpolation in log space.

    Args:
        sigma: numpy value (must expose `.shape`) holding the noise level(s) to convert.
        log_sigmas: 1-D numpy array of log noise levels that is indexed and interpolated between.

    Returns:
        numpy array shaped like `sigma` holding the interpolated index; the interpolation weight is clipped to
        `[0, 1]`, so the result stays within `[0, len(log_sigmas) - 1]`.
    """
    target = np.log(sigma)

    # Signed distance of the target to every table entry (table entries along axis 0)
    offsets = target - log_sigmas[:, np.newaxis]

    # Index of the bracketing pair: last entry not above the target, capped so `upper` stays in range
    lower = np.cumsum((offsets >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    upper = lower + 1

    lower_val = log_sigmas[lower]
    upper_val = log_sigmas[upper]

    # Interpolation weight between the two bracketing entries
    weight = np.clip((lower_val - target) / (lower_val - upper_val), 0, 1)

    # Blend the two indices and restore the input's shape
    index = (1 - weight) * lower + weight * upper
    return index.reshape(sigma.shape)


# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
    """
    Constructs the noise schedule of Karras et al. (2022).

    Args:
        in_sigmas: sequence of noise levels; its first element is used as the maximum and its last element as
            the minimum of the new schedule.
        num_inference_steps: number of noise levels to generate.

    Returns:
        `num_inference_steps` noise levels running from `in_sigmas[0]` to `in_sigmas[-1]`.
    """
    sigma_min: float = in_sigmas[-1].item()
    sigma_max: float = in_sigmas[0].item()

    rho = 7.0  # 7.0 is the value used in the paper
    inv_rho = 1 / rho

    ramp = np.linspace(0, 1, num_inference_steps)
    start = sigma_max**inv_rho
    stop = sigma_min**inv_rho
    return (start + ramp * (stop - start)) ** rho
timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the converted model output. """ # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. if self.config.variance_type in ["learned", "learned_range"]: model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverMultistepScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": # DPM-Solver and DPM-Solver++ only need the "mean" output. 
def dpm_solver_first_order_update(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    prev_timestep: int,
    sample: torch.FloatTensor,
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    One step for the first-order DPM-Solver (equivalent to DDIM).

    See https://arxiv.org/abs/2206.00927 for the detailed derivation.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        noise (`torch.FloatTensor`, *optional*):
            noise term used by the `sde-dpmsolver` and `sde-dpmsolver++` algorithm types. It is required for
            those two types and ignored by the others.

    Returns:
        `torch.FloatTensor`: the sample tensor at the previous timestep.

    Raises:
        ValueError: if `algorithm_type` is not one of `dpmsolver`, `dpmsolver++`, `sde-dpmsolver`,
            `sde-dpmsolver++`, or if an SDE algorithm type is configured and `noise` is `None`.
    """
    algorithm_type = self.config.algorithm_type
    if algorithm_type not in ("dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"):
        # Without this check an unknown type would fall through every branch below and fail with an
        # UnboundLocalError on `x_t`.
        raise ValueError(
            f"algorithm_type given as {algorithm_type} must be one of `dpmsolver`, `dpmsolver++`,"
            " `sde-dpmsolver` or `sde-dpmsolver++`."
        )
    if algorithm_type in ("sde-dpmsolver", "sde-dpmsolver++") and noise is None:
        # An explicit raise (rather than `assert`) keeps the check alive under `python -O`.
        raise ValueError(f"`noise` must be provided when algorithm_type is {algorithm_type}.")

    # Index convention: *_t refers to the target (previous) timestep, *_s to the current one.
    lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
    alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
    sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
    h = lambda_t - lambda_s

    if algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
    elif algorithm_type == "sde-dpmsolver++":
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    else:  # "sde-dpmsolver"
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    return x_t
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) elif self.config.algorithm_type == "sde-dpmsolver++": assert noise is not None if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0 * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise ) elif self.config.algorithm_type == "sde-dpmsolver": assert noise is not None 
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.FloatTensor],
    timestep_list: List[int],
    prev_timestep: int,
    sample: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    One step for the third-order multistep DPM-Solver.

    Args:
        model_output_list (`List[torch.FloatTensor]`):
            direct outputs from learned diffusion model at current and latter timesteps. The last three
            entries are used, the final one belonging to the current timestep.
        timestep_list (`List[int]`):
            current and latter discrete timesteps in the diffusion chain, in the same order as
            `model_output_list` (the last three entries are used).
        prev_timestep (`int`): previous discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the sample tensor at the previous timestep.

    Raises:
        NotImplementedError: if `algorithm_type` is anything other than `dpmsolver` or `dpmsolver++`;
            this update takes no noise argument and has no SDE formula.
    """
    algorithm_type = self.config.algorithm_type
    if algorithm_type not in ("dpmsolver", "dpmsolver++"):
        # Previously such a configuration fell through both branches and crashed with an
        # UnboundLocalError on `x_t`.
        raise NotImplementedError(
            f"The third-order multistep update is not implemented for algorithm_type {algorithm_type};"
            " use `dpmsolver` or `dpmsolver++`, or a `solver_order` of at most 2."
        )

    t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3]
    m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3]
    lambda_t, lambda_s0, lambda_s1, lambda_s2 = (
        self.lambda_t[t],
        self.lambda_t[s0],
        self.lambda_t[s1],
        self.lambda_t[s2],
    )
    alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
    sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

    # Step sizes in lambda and their ratios relative to the step being taken.
    h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
    r0, r1 = h_0 / h, h_1 / h

    # Difference terms built from the three most recent model outputs.
    D0 = m0
    D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

    if algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * D0
            + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
        )
    else:  # "dpmsolver"
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        x_t = (
            (alpha_t / alpha_s0) * sample
            - (sigma_t * (torch.exp(h) - 1.0)) * D0
            - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2
        )
    return x_t
When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] lower_order_final = ( (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 ) lower_order_second = ( (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 ) model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: noise = randn_tensor( model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype ) else: noise = None if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: prev_sample = self.dpm_solver_first_order_update( model_output, timestep, prev_timestep, sample, noise=noise ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_third_order_update( self.model_outputs, timestep_list, prev_timestep, sample ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 if 
not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py ================================================ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver from dataclasses import dataclass from typing import List, Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, ) @flax.struct.dataclass class DPMSolverMultistepSchedulerState: common: CommonSchedulerState alpha_t: jnp.ndarray sigma_t: jnp.ndarray lambda_t: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None # running values model_outputs: Optional[jnp.ndarray] = None lower_order_nums: Optional[jnp.int32] = None prev_timestep: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, alpha_t: jnp.ndarray, sigma_t: jnp.ndarray, lambda_t: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, alpha_t=alpha_t, sigma_t=sigma_t, lambda_t=lambda_t, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxDPMSolverMultistepSchedulerOutput(FlaxSchedulerOutput): state: DPMSolverMultistepSchedulerState class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. 
Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality samples, and it can generate quite good samples even in only 10 steps. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. 
We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v_prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
def create_state(self, common: Optional["CommonSchedulerState"] = None) -> "DPMSolverMultistepSchedulerState":
    """
    Build the initial scheduler state.

    Args:
        common (`CommonSchedulerState`, *optional*):
            shared schedule state providing `alphas_cumprod`. When `None`, it is created with
            `CommonSchedulerState.create(self)`.

    Returns:
        `DPMSolverMultistepSchedulerState`: state holding `alpha_t`, `sigma_t`, `lambda_t`, the initial
        noise sigma and the full descending range of training timesteps.

    Raises:
        NotImplementedError: if `algorithm_type` is not `dpmsolver` / `dpmsolver++`, or `solver_type` is
            not `midpoint` / `heun`.
    """
    # settings for DPM-Solver: reject unsupported configurations before building any state
    if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
        raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
    if self.config.solver_type not in ["midpoint", "heun"]:
        raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")

    if common is None:
        common = CommonSchedulerState.create(self)

    # Currently we only support VP-type noise schedule
    alpha_t = jnp.sqrt(common.alphas_cumprod)
    sigma_t = jnp.sqrt(1 - common.alphas_cumprod)
    lambda_t = jnp.log(alpha_t) - jnp.log(sigma_t)

    # standard deviation of the initial noise distribution
    init_noise_sigma = jnp.array(1.0, dtype=self.dtype)

    timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

    return DPMSolverMultistepSchedulerState.create(
        common=common,
        alpha_t=alpha_t,
        sigma_t=sigma_t,
        lambda_t=lambda_t,
        init_noise_sigma=init_noise_sigma,
        timesteps=timesteps,
    )
def convert_model_output(
    self,
    state: "DPMSolverMultistepSchedulerState",
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
) -> jnp.ndarray:
    """
    Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.

    DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
    discretize an integral of the data prediction model. So we need to first convert the model output to the
    corresponding type to match the algorithm.

    Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
    DPM-Solver++ for both noise prediction model and data prediction model.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            scheduler state; its `alpha_t` and `sigma_t` arrays are indexed by `timestep`.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.

    Returns:
        `jnp.ndarray`: the converted model output (a data prediction for `dpmsolver++`, a noise prediction
        for `dpmsolver`).

    Raises:
        ValueError: if `prediction_type` is not one of `epsilon`, `sample` or `v_prediction`.
    """
    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type == "dpmsolver++":
        if self.config.prediction_type == "epsilon":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # Dynamic thresholding in https://arxiv.org/abs/2205.11487
            # `dynamic_thresholding_ratio` is a fraction in [0, 1], so it must go through `jnp.quantile`;
            # `jnp.percentile` expects a value in [0, 100].
            abs_per_sample = jnp.abs(x0_pred).reshape(x0_pred.shape[0], -1)
            dynamic_max_val = jnp.quantile(abs_per_sample, self.config.dynamic_thresholding_ratio, axis=1)
            dynamic_max_val = jnp.maximum(dynamic_max_val, self.config.sample_max_value)
            # One threshold per leading-axis entry: add singleton axes so it broadcasts along that axis
            # instead of being matched against the trailing axis of `x0_pred`.
            dynamic_max_val = dynamic_max_val.reshape((-1,) + (1,) * (x0_pred.ndim - 1))
            x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
        return x0_pred
    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type == "dpmsolver":
        if self.config.prediction_type == "epsilon":
            return model_output
        elif self.config.prediction_type == "sample":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            epsilon = (sample - alpha_t * model_output) / sigma_t
            return epsilon
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = state.alpha_t[timestep], state.sigma_t[timestep]
            epsilon = alpha_t * model_output + sigma_t * sample
            return epsilon
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
            )
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m0, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (jnp.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s0) * sample - (alpha_t * (jnp.exp(-h) - 1.0)) * D0 + (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (jnp.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s0) * sample - (sigma_t * (jnp.exp(h) - 1.0)) * D0 - (sigma_t * ((jnp.exp(h) - 1.0) / h - 1.0)) * D1 ) return x_t def multistep_dpm_solver_third_order_update( self, state: DPMSolverMultistepSchedulerState, model_output_list: jnp.ndarray, timestep_list: List[int], prev_timestep: int, sample: jnp.ndarray, ) -> jnp.ndarray: """ One step for the third-order multistep DPM-Solver. Args: model_output_list (`List[jnp.ndarray]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. 
def step(
    self,
    state: DPMSolverMultistepSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxDPMSolverMultistepSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by DPM-Solver. Core function to propagate the diffusion process
    from the learned model outputs (most often the predicted noise).

    Every candidate update (first, second and third order) is evaluated on each call; the result that is
    returned is then picked from `solver_order`, `lower_order_final`, `state.lower_order_nums` and the
    position of `timestep` in `state.timesteps`.

    Args:
        state (`DPMSolverMultistepSchedulerState`):
            the `FlaxDPMSolverMultistepScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than FlaxDPMSolverMultistepSchedulerOutput class

    Returns:
        [`FlaxDPMSolverMultistepSchedulerOutput`] or `tuple`: [`FlaxDPMSolverMultistepSchedulerOutput`] if
        `return_dict` is True, otherwise the tuple `(prev_sample, state)`.

    Raises:
        ValueError: if `state.num_inference_steps` is `None`.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Position of `timestep` in the schedule; `size=1` fixes the number of returned indices to one.
    (step_index,) = jnp.where(state.timesteps == timestep, size=1)
    step_index = step_index[0]

    # The last schedule entry steps to timestep 0; every other entry steps to its successor.
    prev_timestep = jax.lax.select(step_index == len(state.timesteps) - 1, 0, state.timesteps[step_index + 1])

    model_output = self.convert_model_output(state, model_output, timestep, sample)

    # Shift the stored outputs one slot towards the front and write the newest into the last slot.
    model_outputs_new = jnp.roll(state.model_outputs, -1, axis=0)
    model_outputs_new = model_outputs_new.at[-1].set(model_output)
    state = state.replace(
        model_outputs=model_outputs_new,
        prev_timestep=prev_timestep,
        cur_sample=sample,
    )

    def step_1(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
        # First-order update from the newest converted model output.
        return self.dpm_solver_first_order_update(
            state,
            state.model_outputs[-1],
            state.timesteps[step_index],
            state.prev_timestep,
            state.cur_sample,
        )

    def step_23(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
        # Computes both the second- and third-order candidates, then selects between them.
        def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            timestep_list = jnp.array([state.timesteps[step_index - 1], state.timesteps[step_index]])
            return self.multistep_dpm_solver_second_order_update(
                state,
                state.model_outputs,
                timestep_list,
                state.prev_timestep,
                state.cur_sample,
            )

        def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray:
            timestep_list = jnp.array(
                [
                    state.timesteps[step_index - 2],
                    state.timesteps[step_index - 1],
                    state.timesteps[step_index],
                ]
            )
            return self.multistep_dpm_solver_third_order_update(
                state,
                state.model_outputs,
                timestep_list,
                state.prev_timestep,
                state.cur_sample,
            )

        step_2_output = step_2(state)
        step_3_output = step_3(state)

        if self.config.solver_order == 2:
            return step_2_output
        elif self.config.lower_order_final and len(state.timesteps) < 15:
            # Short schedules: fall back to second order while fewer than two lower-order steps have
            # been taken, and again on the second-to-last schedule entry.
            return jax.lax.select(
                state.lower_order_nums < 2,
                step_2_output,
                jax.lax.select(
                    step_index == len(state.timesteps) - 2,
                    step_2_output,
                    step_3_output,
                ),
            )
        else:
            return jax.lax.select(
                state.lower_order_nums < 2,
                step_2_output,
                step_3_output,
            )

    # NOTE(review): both candidates are evaluated unconditionally, including `step_23` when
    # `solver_order == 1` — verify the multistep updates can index `state.model_outputs` in that case.
    step_1_output = step_1(state)
    step_23_output = step_23(state)

    if self.config.solver_order == 1:
        prev_sample = step_1_output
    elif self.config.lower_order_final and len(state.timesteps) < 15:
        # Short schedules: first order is used on the very first step and on the last schedule entry.
        prev_sample = jax.lax.select(
            state.lower_order_nums < 1,
            step_1_output,
            jax.lax.select(
                step_index == len(state.timesteps) - 1,
                step_1_output,
                step_23_output,
            ),
        )
    else:
        prev_sample = jax.lax.select(
            state.lower_order_nums < 1,
            step_1_output,
            step_23_output,
        )

    # Count completed steps, capped at `solver_order`.
    state = state.replace(
        lower_order_nums=jnp.minimum(state.lower_order_nums + 1, self.config.solver_order),
    )

    if not return_dict:
        return (prev_sample, state)

    return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state)
# Adapted from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar (error message spelling fixed)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with `num_diffusion_timesteps` entries, the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or `sde-dpmsolver++`. 
The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance; this indicates whether the model's output contains the predicted Gaussian variance.
For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps`, this also updates `self.noisiest_timestep` and `self.num_inference_steps` and resets
    the multistep history (`self.model_outputs`, `self.lower_order_nums`).

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model. The stored
            `self.num_inference_steps` can be smaller, because duplicate timesteps are dropped.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    """
    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    # `lambda_min_clipped` is a registered config option, so it is read from `self.config`
    # like every other option in this class instead of through a bare instance attribute.
    clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
    # `searchsorted` returns a 0-dim tensor; keep the bound as a plain int so that it can be
    # handed to numpy and used as an index without relying on tensor -> array coercion.
    self.noisiest_timestep = self.config.num_train_timesteps - 1 - int(clipped_idx)
    # Evenly spaced, ascending from 0; the endpoint itself (`noisiest_timestep`) is dropped.
    timesteps = (
        np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
    )

    if self.use_karras_sigmas:
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = timesteps.copy().astype(np.int64)

    # when num_inference_steps == num_train_timesteps, we can end up with
    # duplicates in timesteps.
    _, unique_indices = np.unique(timesteps, return_index=True)
    # `np.unique` sorts by value; sorting the first-occurrence indices restores schedule order.
    timesteps = timesteps[np.sort(unique_indices)]

    self.timesteps = torch.from_numpy(timesteps).to(device)

    self.num_inference_steps = len(timesteps)

    # Start a fresh multistep history for the new schedule.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, height, width = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * height * width) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, height, width) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(sigma) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. 
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
    self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
    """
    Translate the raw network output into the quantity the configured solver integrates.

    The `dpmsolver++` / `sde-dpmsolver++` algorithms work on a data (x0) prediction, while `dpmsolver` /
    `sde-dpmsolver` work on a noise (epsilon) prediction. The model itself may predict `epsilon`, `sample` or
    `v_prediction`; the two choices are independent, so every combination is converted here.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep, used to index `self.alpha_t` / `self.sigma_t`.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the converted model output (x0 prediction or epsilon prediction).

    Raises:
        ValueError: if `config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    algorithm_type = self.config.algorithm_type
    wants_data = algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]
    if not wants_data and algorithm_type not in ["dpmsolver", "sde-dpmsolver"]:
        return None

    prediction_type = self.config.prediction_type
    if prediction_type not in ("epsilon", "sample", "v_prediction"):
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction` for the DPMSolverMultistepScheduler."
        )

    if wants_data:
        # DPM-Solver++ needs to solve an integral of the data prediction model.
        if prediction_type == "sample":
            x0_pred = model_output
        elif prediction_type == "epsilon":
            # Only the "mean" part of the output is needed: keep the first three channels
            # when the model also predicts a variance.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        else:
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = alpha_t * sample - sigma_t * model_output

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)
        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    if prediction_type == "epsilon":
        # Same channel selection as above for variance-predicting models.
        if self.config.variance_type in ["learned", "learned_range"]:
            epsilon = model_output[:, :3]
        else:
            epsilon = model_output
    elif prediction_type == "sample":
        alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
        epsilon = (sample - alpha_t * model_output) / sigma_t
    else:
        alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
        epsilon = alpha_t * model_output + sigma_t * sample

    if self.config.thresholding:
        # Thresholding is defined on x0: go to x0, threshold, and come back to epsilon.
        alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
        x0_pred = (sample - sigma_t * epsilon) / alpha_t
        x0_pred = self._threshold_sample(x0_pred)
        epsilon = (sample - alpha_t * x0_pred) / sigma_t
    return epsilon
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.FloatTensor],
    timestep_list: List[int],
    prev_timestep: int,
    sample: torch.FloatTensor,
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    One step for the second-order multistep DPM-Solver.

    Only the last two entries of `model_output_list` and `timestep_list` are read: the last entry belongs to the
    current step and the one before it to the step taken just before.

    Args:
        model_output_list (`List[torch.FloatTensor]`):
            converted model outputs; `[-1]` is the current one, `[-2]` the one from the preceding step.
        timestep_list (`List[int]`):
            discrete timesteps matching `model_output_list`; `[-1]` is the timestep of `sample`.
        prev_timestep (`int`): the discrete timestep the returned sample corresponds to.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        noise (`torch.FloatTensor`, *optional*):
            noise added by the `sde-dpmsolver` / `sde-dpmsolver++` variants; those variants assert that it is
            not `None`. Unused by the other algorithm types.

    Returns:
        `torch.FloatTensor`: the sample tensor at `prev_timestep`.
    """
    # t is the target timestep, s0 the timestep of `sample`, s1 the one before it.
    t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
    m0, m1 = model_output_list[-1], model_output_list[-2]
    lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
    alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
    sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
    # h is the step about to be taken and h_0 the previously taken step, both measured in lambda.
    h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0 is the current model output; D1 is the difference of the two outputs scaled by 1 / r0.
    D0, D1 = m0, (1.0 / r0) * (m0 - m1)
    # NOTE: there is no fallback branch below, so an algorithm_type / solver_type pair that
    # matches none of the cases leaves `x_t` unassigned when it is returned.
    if self.config.algorithm_type == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0) * sample
                - (alpha_t * (torch.exp(-h) - 1.0)) * D0
                + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
            )
    elif self.config.algorithm_type == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
            )
    elif self.config.algorithm_type == "sde-dpmsolver++":
        # Stochastic variant of the data-prediction update: the last term adds `noise`.
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (sigma_t / sigma_s0 * torch.exp(-h)) * sample
                + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
                + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
                + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
            )
    elif self.config.algorithm_type == "sde-dpmsolver":
        # Stochastic variant of the noise-prediction update: the last term adds `noise`.
        assert noise is not None
        if self.config.solver_type == "midpoint":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - (sigma_t * (torch.exp(h) - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
        elif self.config.solver_type == "heun":
            x_t = (
                (alpha_t / alpha_s0) * sample
                - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0
                - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1
                + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
            )
    return x_t
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    generator=None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance `sample` by one step of the multistep DPM-Solver.

    The converted model output is appended to the history in `self.model_outputs`, and a first-, second- or
    third-order update is chosen from the configured solver order, the number of outputs gathered so far
    (`self.lower_order_nums`) and, for schedules shorter than 15 steps with `lower_order_final` enabled, the
    position within the schedule.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        generator: forwarded to `randn_tensor` when an `sde-*` algorithm type draws noise; unused otherwise.
        return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

    Returns:
        [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
        True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `set_timesteps` has not been run yet.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    num_steps = len(self.timesteps)
    last_index = num_steps - 1

    # Locate the current timestep in the schedule; an unknown timestep is treated as the last one.
    matches = (self.timesteps == timestep).nonzero()
    step_index = matches.item() if len(matches) != 0 else last_index

    if step_index == last_index:
        prev_timestep = self.noisiest_timestep
    else:
        prev_timestep = self.timesteps[step_index + 1]

    # Order reduction at the end of the schedule is only applied to short schedules.
    reduce_order = self.config.lower_order_final and num_steps < 15
    lower_order_final = (step_index == last_index) and reduce_order
    lower_order_second = (step_index == num_steps - 2) and reduce_order

    model_output = self.convert_model_output(model_output, timestep, sample)

    # Shift the stored outputs one slot towards the front and put the newest one last.
    for slot in range(1, self.config.solver_order):
        self.model_outputs[slot - 1] = self.model_outputs[slot]
    self.model_outputs[-1] = model_output

    noise = None
    if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        noise = randn_tensor(
            model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
        )

    solver_order = self.config.solver_order
    if solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(
            model_output, timestep, prev_timestep, sample, noise=noise
        )
    elif solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        timestep_list = [self.timesteps[step_index - 1], timestep]
        prev_sample = self.multistep_dpm_solver_second_order_update(
            self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
        )
    else:
        timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
        prev_sample = self.multistep_dpm_solver_third_order_update(
            self.model_outputs, timestep_list, prev_timestep, sample
        )

    # Count gathered outputs until the configured order is reached.
    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    With a single seed one tree is built and calls return that tree's increment. With a sequence of seeds (one per
    element of the first dimension of `x`) one tree is built per seed and calls return the stacked increments.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        """
        Args:
            x (Tensor): tensor used as the template for the default initial value `w0` (`torch.zeros_like(x)`).
            t0, t1: interval end points; they are put in ascending order before being handed to the trees.
            seed (int or sequence of int, *optional*): entropy for the tree(s). A random seed is drawn when `None`.
            **kwargs: forwarded to `torchsde.BrownianTree`. A `w0` entry, if present, replaces the default
                initial value.

        Raises:
            ValueError: if a sequence of seeds is given whose length differs from `x.shape[0]`.
        """
        t0, t1, self.sign = self.sort(t0, t1)
        # `pop` rather than `get`: `w0` is passed positionally below, so leaving it inside
        # `kwargs` would hand it to `torchsde.BrownianTree` a second time and raise a TypeError.
        w0 = kwargs.pop("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        # A seed without a length is a single scalar seed. This is checked explicitly (not inside
        # an `assert`) so that the detection still works when Python runs with `-O`.
        try:
            num_seeds = len(seed)
        except TypeError:
            seed = [seed]
            self.batched = False
        else:
            if num_seeds != x.shape[0]:
                raise ValueError(
                    f"Expected one seed per batch item ({x.shape[0]}), but {num_seeds} seeds were given."
                )
            # Each per-item tree gets the initial value of a single batch item.
            w0 = w0[0]
            self.batched = True
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return `(low, high, sign)`, where `sign` is `1` if `a < b` and `-1` otherwise."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Return the increment between `t0` and `t1`, stacked over trees when batched."""
        t0, t1, sign = self.sort(t0, t1)
        # Both the construction-time and the call-time reordering flip the sign of the increment.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    For each of the `num_diffusion_timesteps` intervals `[i / N, (i + 1) / N]` the beta is
    `1 - alpha_bar(t2) / alpha_bar(t1)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): a 1-D float32 tensor with `num_diffusion_timesteps` entries, the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # One beta per interval [i / N, (i + 1) / N]; the cap keeps the last interval of the
    # cosine schedule (where alpha_bar approaches 0) away from beta == 1.
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. noise_sampler_seed (`int`, *optional*, defaults to `None`): The random seed to use for the noise sampler. If `None`, a random seed will be generated. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion.
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.00085, # sensible defaults beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, noise_sampler_seed: Optional[int] = None, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self.use_karras_sigmas = use_karras_sigmas self.noise_sampler = None self.noise_sampler_seed = noise_sampler_seed # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. 
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps` and `self.sigmas`, this resets the per-run state (`self.sample`,
    `self.mid_point_sigma`, `self._index_counter`).

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, optional):
            overrides `self.config.num_train_timesteps` when given (and non-zero).

    Raises:
        ValueError: if `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # whole-number timesteps (multiples of the integer ratio), stored as float
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # rounded to whole timesteps, stored as float
        timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(sigmas)
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

    # `__init__` calls this method before it assigns `self.use_karras_sigmas`, so reading the attribute
    # directly raises AttributeError during construction. Prefer the attribute when it exists, otherwise
    # fall back to the registered config value.
    use_karras_sigmas = getattr(self, "use_karras_sigmas", self.config.use_karras_sigmas)
    if use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

    second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)

    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)
    # every interior sigma appears twice; the first sigma and the trailing 0.0 appear once
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

    timesteps = torch.from_numpy(timesteps)
    second_order_timesteps = torch.from_numpy(second_order_timesteps)
    timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
    # odd positions carry the timesteps returned by `_second_order_timesteps`
    timesteps[1::2] = second_order_timesteps

    if str(device).startswith("mps"):
        # mps does not support float64
        self.timesteps = timesteps.to(device, dtype=torch.float32)
    else:
        self.timesteps = timesteps.to(device=device)

    # empty first order variables
    self.sample = None
    self.mid_point_sigma = None

    # for exp beta schedules, such as the one for `pipeline_shap_e.py`
    # we need an index counter
    self._index_counter = defaultdict(int)
def step(
    self,
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: Union[float, torch.FloatTensor],
    sample: Union[torch.FloatTensor, np.ndarray],
    return_dict: bool = True,
    s_noise: float = 1.0,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    The method alternates between two states: when `self.sample` is `None` (`state_in_first_order`) it steps to a
    midpoint and stores the incoming `sample`; on the following call it re-uses that stored sample and clears it.

    Args:
        model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model.
        timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain.
        sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process.
        return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True.
        s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.

    Raises:
        NotImplementedError: if `self.config.prediction_type` is `sample`.
        ValueError: if `self.config.prediction_type` is not `epsilon`, `v_prediction` or `sample`.
    """
    step_index = self.index_for_timestep(timestep)

    # advance index counter by 1
    timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
    self._index_counter[timestep_int] += 1

    # Create a noise sampler if it hasn't been created yet
    # (built once, from the first `sample` seen and the range of strictly positive sigmas)
    if self.noise_sampler is None:
        min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
        self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)

    # Define functions to compute sigma and t from each other
    # (mutual inverses: sigma = exp(-t), t = -log(sigma))
    def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
        return _t.neg().exp()

    def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
        return _sigma.log().neg()

    if self.state_in_first_order:
        sigma = self.sigmas[step_index]
        sigma_next = self.sigmas[step_index + 1]
    else:
        # 2nd order
        sigma = self.sigmas[step_index - 1]
        sigma_next = self.sigmas[step_index]

    # Set the midpoint and step size for the current step
    midpoint_ratio = 0.5
    t, t_next = t_fn(sigma), t_fn(sigma_next)
    delta_time = t_next - t
    t_proposed = t + delta_time * midpoint_ratio

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
        pred_original_sample = sample - sigma_input * model_output
    elif self.config.prediction_type == "v_prediction":
        sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )
    elif self.config.prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    if sigma_next == 0:
        # final step down to sigma 0: single first-order update, no noise is added
        derivative = (sample - pred_original_sample) / sigma
        dt = sigma_next - sigma
        prev_sample = sample + derivative * dt
    else:
        if self.state_in_first_order:
            # first stage only advances to the midpoint
            t_next = t_proposed
        else:
            # second stage restarts from the sample stored by the first stage
            sample = self.sample

        sigma_from = sigma_fn(t)
        sigma_to = sigma_fn(t_next)
        # split the target noise level into a part reached without noise (sigma_down)
        # and a part re-injected through the noise sampler (sigma_up)
        sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        ancestral_t = t_fn(sigma_down)
        prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
            t - ancestral_t
        ).expm1() * pred_original_sample
        prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up

        # NOTE(review): this state update is placed inside the `else` branch; the nesting was
        # reconstructed from whitespace-collapsed text -- confirm against the upstream file.
        if self.state_in_first_order:
            # store for 2nd order step
            self.sample = sample
            self.mid_point_sigma = sigma_fn(t_next)
        else:
            # free for "first order mode"
            self.sample = None
            self.mid_point_sigma = None

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_dpmsolver_singlestep.py ================================================ # Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver import math from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import logging from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality samples, and it can generate quite good samples even in only 10 steps. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 Currently, we support the singlestep DPM-Solver for both noise prediction models and data prediction models. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend using `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`): indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, or `v_prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. algorithm_type (`str`, default `dpmsolver++`): the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`.
The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are slightly better, so we recommend to use the `midpoint` type. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable this to use up all the function evaluations. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. lambda_min_clipped (`float`, default `-inf`): the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for cosine (squaredcos_cap_v2) noise schedule. variance_type (`str`, *optional*): Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the Gaussian distribution in the model's output. 
DPM-Solver only needs the "mean" output because it is based on diffusion ODEs. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, lambda_min_clipped: float = -float("inf"), variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
def get_order_list(self, num_inference_steps: int) -> List[int]:
    """
    Computes the solver order at each time step.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.

    Returns:
        `List[int]`: the solver order to use at each step. With `lower_order_final` the list has one entry per
        step and ends with lower-order steps; without it, only whole groups of `solver_order` steps are listed.

    Raises:
        ValueError: if `self.config.solver_order` is not 1, 2 or 3.
    """
    steps = num_inference_steps
    order = self.config.solver_order
    if order not in (1, 2, 3):
        # previously this fell through every branch and failed with an UnboundLocalError on `orders`
        raise ValueError(f"`solver_order` must be 1, 2 or 3 for {self.__class__}, but got {order}.")

    if self.config.lower_order_final:
        if order == 3:
            if steps % 3 == 0:
                orders = [1, 2, 3] * (steps // 3 - 1) + [1, 2] + [1]
            elif steps % 3 == 1:
                orders = [1, 2, 3] * (steps // 3) + [1]
            else:
                orders = [1, 2, 3] * (steps // 3) + [1, 2]
        elif order == 2:
            if steps % 2 == 0:
                orders = [1, 2] * (steps // 2)
            else:
                orders = [1, 2] * (steps // 2) + [1]
        else:
            orders = [1] * steps
    else:
        if order == 3:
            orders = [1, 2, 3] * (steps // 3)
        elif order == 2:
            orders = [1, 2] * (steps // 2)
        else:
            orders = [1] * steps
    return orders


def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Also resets `self.model_outputs` and `self.sample`, and recomputes `self.order_list`. If
    `lower_order_final` is disabled and `num_inference_steps` is not a multiple of `solver_order`, the config is
    switched to `lower_order_final=True` and a warning is logged.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    """
    self.num_inference_steps = num_inference_steps
    # Clipping the minimum of all lambda(t) for numerical stability.
    # This is critical for cosine (squaredcos_cap_v2) noise schedule.
    # `searchsorted` returns a 0-dim tensor; use a plain int in the numpy arithmetic below.
    clipped_idx = int(torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped))
    timesteps = (
        np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
        .round()[::-1][:-1]
        .copy()
        .astype(np.int64)
    )

    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    if self.config.use_karras_sigmas:
        log_sigmas = np.log(sigmas)
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
        timesteps = np.flip(timesteps).copy().astype(np.int64)

    self.sigmas = torch.from_numpy(sigmas)

    self.timesteps = torch.from_numpy(timesteps).to(device)
    self.model_outputs = [None] * self.config.solver_order
    self.sample = None

    if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
        # `warning` replaces the deprecated `warn`; the message needs the f-prefix to show the config
        logger.warning(
            f"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference_steps` when using `lower_order_final=False`."
        )
        self.register_to_config(lower_order_final=True)

    self.order_list = self.get_order_list(num_inference_steps)
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, height, width = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * height * width) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, height, width) sample = sample.to(dtype) return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(sigma) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. 
def convert_model_output(
    self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
    """
    Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.

    DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
    discretize an integral of the data prediction model. So we need to first convert the model output to the
    corresponding type to match the algorithm.

    Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or
    DPM-Solver++ for both noise prediction model and data prediction model.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.

    Returns:
        `torch.FloatTensor`: the converted model output (predicted `x0` for `dpmsolver++`, predicted noise for
        `dpmsolver`).

    Raises:
        ValueError: if `self.config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type == "dpmsolver++":
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output. Both documented variance
            # settings ("learned" and "learned_range") append variance channels, so strip them for both.
            # NOTE(review): the slice keeps the first 3 channels, i.e. assumes a 3-channel mean -- confirm.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = (sample - sigma_t * model_output) / alpha_t
        elif self.config.prediction_type == "sample":
            x0_pred = model_output
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            x0_pred = alpha_t * sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred
    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type == "dpmsolver":
        if self.config.prediction_type == "epsilon":
            # DPM-Solver and DPM-Solver++ only need the "mean" output.
            if self.config.variance_type in ["learned", "learned_range"]:
                model_output = model_output[:, :3]
            return model_output
        elif self.config.prediction_type == "sample":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            epsilon = (sample - alpha_t * model_output) / sigma_t
            return epsilon
        elif self.config.prediction_type == "v_prediction":
            alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
            epsilon = alpha_t * model_output + sigma_t * sample
            return epsilon
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction` for the DPMSolverSinglestepScheduler."
            )
timestep (`int`): current discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] h = lambda_t - lambda_s if self.config.algorithm_type == "dpmsolver++": x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output return x_t def singlestep_dpm_solver_second_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the second-order singlestep DPM-Solver. It computes the solution at time `prev_timestep` from the time `timestep_list[-2]`. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] alpha_t, alpha_s1 = self.alpha_t[t], self.alpha_t[s1] sigma_t, sigma_s1 = self.sigma_t[t], self.sigma_t[s1] h, h_0 = lambda_t - lambda_s1, lambda_s0 - lambda_s1 r0 = h_0 / h D0, D1 = m1, (1.0 / r0) * (m0 - m1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2211.01095 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s1) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s1) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s1) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s1) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) return x_t def singlestep_dpm_solver_third_order_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, ) -> torch.FloatTensor: """ One step for the third-order singlestep DPM-Solver. It computes the solution at time `prev_timestep` from the time `timestep_list[-3]`. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. 
sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1], self.lambda_t[s2], ) alpha_t, alpha_s2 = self.alpha_t[t], self.alpha_t[s2] sigma_t, sigma_s2 = self.sigma_t[t], self.sigma_t[s2] h, h_0, h_1 = lambda_t - lambda_s2, lambda_s0 - lambda_s2, lambda_s1 - lambda_s2 r0, r1 = h_0 / h, h_1 / h D0 = m2 D1_0, D1_1 = (1.0 / r1) * (m1 - m2), (1.0 / r0) * (m0 - m2) D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1) D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1) if self.config.algorithm_type == "dpmsolver++": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (sigma_t / sigma_s2) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1_1 ) elif self.config.solver_type == "heun": x_t = ( (sigma_t / sigma_s2) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 ) elif self.config.algorithm_type == "dpmsolver": # See https://arxiv.org/abs/2206.00927 for detailed derivations if self.config.solver_type == "midpoint": x_t = ( (alpha_t / alpha_s2) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1_1 ) elif self.config.solver_type == "heun": x_t = ( (alpha_t / alpha_s2) * sample - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 ) return x_t def singlestep_dpm_solver_update( self, model_output_list: List[torch.FloatTensor], timestep_list: List[int], 
prev_timestep: int, sample: torch.FloatTensor, order: int, ) -> torch.FloatTensor: """ One step for the singlestep DPM-Solver. Args: model_output_list (`List[torch.FloatTensor]`): direct outputs from learned diffusion model at current and latter timesteps. timestep (`int`): current and latter discrete timestep in the diffusion chain. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. order (`int`): the solver order at this step. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. """ if order == 1: return self.dpm_solver_first_order_update(model_output_list[-1], timestep_list[-1], prev_timestep, sample) elif order == 2: return self.singlestep_dpm_solver_second_order_update( model_output_list, timestep_list, prev_timestep, sample ) elif order == 3: return self.singlestep_dpm_solver_third_order_update( model_output_list, timestep_list, prev_timestep, sample ) else: raise ValueError(f"Order must be 1, 2, 3, got {order}") def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the singlestep DPM-Solver. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] model_output = self.convert_model_output(model_output, timestep, sample) for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output order = self.order_list[step_index] # For img2img denoising might start with order>1 which is not possible # In this case make sure that the first two steps are both order=1 while self.model_outputs[-order] is None: order -= 1 # For single-step solvers, we use the initial value at each time with order = 1. if order == 1: self.sample = sample timestep_list = [self.timesteps[step_index - i] for i in range(order - 1, 0, -1)] + [timestep] prev_sample = self.singlestep_dpm_solver_update( self.model_outputs, timestep_list, prev_timestep, self.sample, order ) if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_euler_ancestral_discrete.py ================================================ # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete class EulerAncestralDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. 
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False @property def init_noise_sigma(self): # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: """ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. Args: sample (`torch.FloatTensor`): input sample timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain Returns: `torch.FloatTensor`: scaled input sample """ if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) self.is_scale_input_called = True return sample def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ ::-1 ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): # mps does not support float64 self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) else: self.timesteps = torch.from_numpy(timesteps).to(device=device) def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`float`): current timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator (`torch.Generator`, optional): Random number generator. return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class Returns: [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if ( isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor) ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" " one of the `scheduler.timesteps` as a timestep." 
), ) if not self.is_scale_input_called: logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": pred_original_sample = sample - sigma * model_output elif self.config.prediction_type == "v_prediction": # * c_out + input * c_skip pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) elif self.config.prediction_type == "sample": raise NotImplementedError("prediction_type not implemented yet: sample") else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) sigma_from = self.sigmas[step_index] sigma_to = self.sigmas[step_index + 1] sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 # 2. 
Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma dt = sigma_down - sigma prev_sample = sample + derivative * dt device = model_output.device noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) prev_sample = prev_sample + noise * sigma_up if not return_dict: return (prev_sample,) return EulerAncestralDiscreteSchedulerOutput( prev_sample=prev_sample, pred_original_sample=pred_original_sample ) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_euler_discrete.py ================================================ # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging, randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete class EulerDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ if alpha_transform_type == "cosine": def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 elif alpha_transform_type == "exp": def alpha_bar_fn(t): return math.exp(t * -12.0) else: raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original k-diffusion implementation by Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. 
beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. prediction_type (`str`, default `"epsilon"`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) interpolation_type (`str`, default `"linear"`, optional): interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of [`"linear"`, `"log_linear"`]. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", interpolation_type: str = "linear", use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.is_scale_input_called = False self.use_karras_sigmas = use_karras_sigmas @property def init_noise_sigma(self): # standard deviation of the initial noise distribution if self.config.timestep_spacing in ["linspace", "trailing"]: return self.sigmas.max() return (self.sigmas.max() ** 2 + 1) ** 0.5 def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: 
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Compute the inference timesteps and matching sigmas. Must be run before inference.

    Args:
        num_inference_steps (`int`):
            number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            device the resulting `timesteps` and `sigmas` tensors are moved to. If `None`, they are not moved.

    Raises:
        ValueError: if the configured `timestep_spacing` or `interpolation_type` is not supported.
    """
    self.num_inference_steps = num_inference_steps
    config = self.config

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    if config.timestep_spacing == "linspace":
        ascending = np.linspace(0, config.num_train_timesteps - 1, num_inference_steps, dtype=float)
        timesteps = ascending[::-1].copy()
    elif config.timestep_spacing == "leading":
        # integer stride between consecutive timesteps
        stride = config.num_train_timesteps // self.num_inference_steps
        scaled = (np.arange(0, num_inference_steps) * stride).round()
        timesteps = scaled[::-1].copy().astype(float)
        timesteps += config.steps_offset
    elif config.timestep_spacing == "trailing":
        # fractional stride; timesteps are rounded after stepping down from the end
        stride = config.num_train_timesteps / self.num_inference_steps
        timesteps = (np.arange(config.num_train_timesteps, 0, -stride)).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # Training sigmas (ascending in timestep index) and their logs; the logs are
    # only consumed by the Karras sigma -> timestep conversion below.
    train_sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    log_sigmas = np.log(train_sigmas)

    if config.interpolation_type == "linear":
        sigmas = np.interp(timesteps, np.arange(0, len(train_sigmas)), train_sigmas)
    elif config.interpolation_type == "log_linear":
        log_max = np.log(train_sigmas[-1])
        log_min = np.log(train_sigmas[0])
        sigmas = torch.linspace(log_max, log_min, num_inference_steps + 1).exp()
    else:
        raise ValueError(
            f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
            " 'linear' or 'log_linear'"
        )

    if self.use_karras_sigmas:
        sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

    # A final sigma of 0.0 is appended so that `step` can read `sigmas[step_index + 1]`.
    padded_sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(padded_sigmas).to(device=device)

    timestep_tensor = torch.from_numpy(timesteps)
    if str(device).startswith("mps"):
        # mps does not support float64
        self.timesteps = timestep_tensor.to(device, dtype=torch.float32)
    else:
        self.timesteps = timestep_tensor.to(device=device)
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
    """
    Advance the sample by one Euler step towards the next (lower) sigma of the schedule.

    Args:
        model_output (`torch.FloatTensor`): direct output from the learned diffusion model.
        timestep (`float` or `torch.FloatTensor`):
            current timestep; it must be one of the values in `self.timesteps`.
        sample (`torch.FloatTensor`): current instance of the sample being created.
        s_churn (`float`): churn amount; divided by the number of steps and capped at `sqrt(2) - 1`.
        s_tmin (`float`): lower sigma bound of the range in which churn is applied.
        s_tmax (`float`): upper sigma bound of the range in which churn is applied.
        s_noise (`float`): multiplier applied to the freshly drawn noise.
        generator (`torch.Generator`, optional): random number generator used for the noise.
        return_dict (`bool`): return an `EulerDiscreteSchedulerOutput` when True, otherwise a tuple.

    Returns:
        `EulerDiscreteSchedulerOutput` or `tuple`: when a tuple is returned, its only element is the
        updated sample tensor.

    Raises:
        ValueError: if `timestep` is an integer (index-like) value, or if the configured
            `prediction_type` is not supported.
    """
    integer_like = (int, torch.IntTensor, torch.LongTensor)
    if isinstance(timestep, integer_like):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if not self.is_scale_input_called:
        logger.warning(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    # Locate the current position in the schedule by exact timestep match.
    step_index = (self.timesteps == timestep).nonzero().item()
    sigma = self.sigmas[step_index]

    # Churn is only active while sigma lies inside [s_tmin, s_tmax].
    if s_tmin <= sigma <= s_tmax:
        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
    else:
        gamma = 0.0

    # The noise is drawn unconditionally, even when gamma is 0 and it goes unused.
    noise = randn_tensor(
        model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
    )
    eps = noise * s_noise
    sigma_hat = sigma * (gamma + 1)

    if gamma > 0:
        # Raise the noise level of the sample from sigma to sigma_hat.
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    # NOTE: "original_sample" should not be an expected prediction_type but is left in for
    # backwards compatibility
    prediction_type = self.config.prediction_type
    if prediction_type in ("original_sample", "sample"):
        pred_original_sample = model_output
    elif prediction_type == "epsilon":
        pred_original_sample = sample - sigma_hat * model_output
    elif prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma_hat

    # Step size in sigma space: from sigma_hat to the next sigma of the schedule.
    dt = self.sigmas[step_index + 1] - sigma_hat
    prev_sample = sample + derivative * dt

    if return_dict:
        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
    return (prev_sample,)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine", ): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements Algorithm 2 (Heun steps) from Karras et al. (2022) for discrete beta schedules. Based on the
    original k-diffusion implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`]
    and [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose
            from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `exp`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
            https://imagen.research.google/video/paper.pdf).
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in
            the noise schedule during the sampling process. If True, the sigmas will be determined according to a
            sequence of noise levels {σi} as defined in Equation (5) of the paper
            https://arxiv.org/pdf/2206.00364.pdf.
        clip_sample (`bool`, default `False`):
            option to clip the predicted original sample for numerical stability.
        clip_sample_range (`float`, default `1.0`):
            the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        timestep_spacing (`str`, default `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
        steps_offset (`int`, default `0`):
            an offset added to the inference timesteps; it is only applied for the `"leading"` spacing.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        clip_sample: Optional[bool] = False,
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        """
        Build the beta schedule and initialise the scheduler with the full training schedule.

        Raises:
            NotImplementedError: if `beta_schedule` is not one of the supported names.
        """
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
        elif beta_schedule == "exp":
            self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        # Mirror of the config value; `set_timesteps` itself reads `self.config.use_karras_sigmas`.
        self.use_karras_sigmas = use_karras_sigmas

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the index into the (interleaved) schedule that corresponds to `timestep`.

        Interior timesteps appear twice in the schedule (see `set_timesteps`), so a per-timestep counter,
        advanced by `step`, selects which occurrence is meant.

        Args:
            timestep: the timestep value to look up (tensor or plain number).
            schedule_timesteps (optional): schedule to search; defaults to `self.timesteps`.

        Returns:
            `int`: position of the selected occurrence of `timestep` in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(self._index_counter) == 0:
            pos = 1 if len(indices) > 1 else 0
        else:
            timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
            pos = self._index_counter[timestep_int]

        return indices[pos].item()

    @property
    def init_noise_sigma(self):
        """Standard deviation of the initial noise distribution."""
        if self.config.timestep_spacing in ["linspace", "trailing"]:
            return self.sigmas.max()

        return (self.sigmas.max() ** 2 + 1) ** 0.5

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Scale the denoising model input by `1 / (sigma**2 + 1) ** 0.5` for the sigma at `timestep`.

        Ensures interchangeability with schedulers that need to scale the denoising model input depending on
        the current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        step_index = self.index_for_timestep(timestep)

        sigma = self.sigmas[step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            num_train_timesteps (`int`, optional):
                number of training timesteps to space the schedule over; falls back to
                `self.config.num_train_timesteps` when not given.

        Raises:
            ValueError: if the configured `timestep_spacing` is not supported.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
        if self.config.timestep_spacing == "linspace":
            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
        elif self.config.timestep_spacing == "leading":
            step_ratio = num_train_timesteps // self.num_inference_steps
            # creates whole-number timesteps by multiplying by the integer ratio
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = num_train_timesteps / self.num_inference_steps
            # creates whole-number timesteps by rounding after stepping down from the end
            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.config.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        # First sigma and the final 0.0 are kept once; every interior sigma is duplicated.
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        timesteps = torch.from_numpy(timesteps)
        # First timestep is kept once; every later timestep is duplicated.
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = timesteps.to(device, dtype=torch.float32)
        else:
            self.timesteps = timesteps.to(device=device)

        # empty dt, derivative and stored sample so the scheduler starts in first-order mode
        # without holding on to a tensor from an interrupted second-order step
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        # for exp beta schedules, such as the one for `pipeline_shap_e.py`
        # we need an index counter
        self._index_counter = defaultdict(int)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """Map `sigma` to a fractional timestep index by linear interpolation in log-sigma space."""
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    @property
    def state_in_first_order(self):
        """True when the next `step` call is the first (Euler) half of a Heun step."""
        return self.dt is None

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: Union[float, torch.FloatTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the
        diffusion process from the learned model outputs (most often the predicted noise).

        Calls alternate between two modes: a first-order call stores the derivative, step size and sample;
        the following second-order call averages the stored derivative with the new one and applies it to
        the stored sample.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if the configured `prediction_type` is not `epsilon`, `sample` or `v_prediction`.
        """
        step_index = self.index_for_timestep(timestep)

        # advance index counter by 1
        timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        self._index_counter[timestep_int] += 1

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
            sigma_next = self.sigmas[step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[step_index - 1]
            sigma_next = self.sigmas[step_index]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_next
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_next
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                " or `v_prediction`"
            )

        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 2. 2nd order / Heun's method
            derivative = (sample - pred_original_sample) / sigma_next
            derivative = (self.prev_derivative + derivative) / 2

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Return `original_samples + noise * sigma`, with sigma looked up per entry of `timesteps`.

        Args:
            original_samples (`torch.FloatTensor`): clean samples.
            noise (`torch.FloatTensor`): noise to add.
            timesteps (`torch.FloatTensor`): one schedule timestep per sample.

        Returns:
            `torch.FloatTensor`: the noised samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        # Append trailing singleton dims until sigma has as many dims as the samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    Improved Pseudo numerical methods for diffusion models (iPNDM) ported from @crowsonkb's amazing k-diffusion
    [library](https://github.com/crowsonkb/v-diffusion-pytorch/blob/987f8985e38208345c1959b0ea767a625831cc9b/diffusion/sampling.py#L296)

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`]
    and [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2202.09778

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        trained_betas (`np.ndarray` or `List[float]`, optional):
            betas used by `set_timesteps` instead of the generated ones when given.
    """

    order = 1

    @register_to_config
    def __init__(
        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
    ):
        # populates `betas`, `alphas`, `timesteps` (and resets `ets`)
        self.set_timesteps(num_train_timesteps)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
        self.pndm_order = 4

        # running values: history of per-step estimates consumed by `step`
        self.ets = []

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                device the resulting timesteps are moved to.
        """
        self.num_inference_steps = num_inference_steps

        # num_inference_steps points descending from 1, followed by an explicit 0.0
        descending = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
        steps = torch.cat([descending, torch.tensor([0.0])])

        trained_betas = self.config.trained_betas
        if trained_betas is None:
            self.betas = torch.sin(steps * math.pi / 2) ** 2
        else:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)

        self.alphas = (1.0 - self.betas**2) ** 0.5

        # The final entry is dropped, so there is one timestep fewer than betas/alphas.
        angles = torch.atan2(self.betas, self.alphas) / math.pi * 2
        self.timesteps = angles[:-1].to(device)

        self.ets = []

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with
        multiple times to approximate the solution.

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
            tensor.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep_index = (self.timesteps == timestep).nonzero().item()
        prev_timestep_index = timestep_index + 1

        # Estimate for the current step; appended to the running history.
        current = sample * self.betas[timestep_index] + model_output * self.alphas[timestep_index]
        self.ets.append(current)

        # Combine up to the four most recent estimates; the weights depend on
        # how many estimates have been collected so far.
        history = self.ets
        collected = len(history)
        if collected == 1:
            ets = history[-1]
        elif collected == 2:
            ets = (3 * history[-1] - history[-2]) / 2
        elif collected == 3:
            ets = (23 * history[-1] - 16 * history[-2] + 5 * history[-3]) / 12
        else:
            ets = (1 / 24) * (55 * history[-1] - 59 * history[-2] + 37 * history[-3] - 9 * history[-4])

        prev_sample = self._get_prev_sample(sample, timestep_index, prev_timestep_index, ets)

        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)
        return (prev_sample,)

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler returns the sample unchanged.

        Args:
            sample (`torch.FloatTensor`): input sample

        Returns:
            `torch.FloatTensor`: the unmodified input sample
        """
        return sample

    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        """Compute the sample at `prev_timestep_index` from the current sample and the combined estimate."""
        alpha, sigma = self.alphas[timestep_index], self.betas[timestep_index]
        next_alpha, next_sigma = self.alphas[prev_timestep_index], self.betas[prev_timestep_index]

        # alpha is floored at 1e-8 to avoid dividing by zero
        pred = (sample - sigma * ets) / max(alpha, 1e-8)
        return next_alpha * pred + ets * next_sigma

    def __len__(self):
        return self.config.num_train_timesteps
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize an alpha_bar function into a beta schedule.

    alpha_bar(t) is the cumulative product of (1 - beta) up to time t in [0, 1]. Beta number `i` is
    `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with one beta per timestep.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    cosine_schedule = alpha_transform_type == "cosine"

    def alpha_bar_fn(t):
        if cosine_schedule:
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
        return math.exp(t * -12.0)

    total = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar_fn((step + 1) / total) / alpha_bar_fn(step / total), max_beta)
        for step in range(total)
    ]
    return torch.tensor(betas, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.00085,  # sensible defaults
    beta_end: float = 0.012,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta / alpha tables and initialise the schedule over all training timesteps.

    `trained_betas`, when given, takes precedence over `beta_schedule`. `prediction_type`, `timestep_spacing` and
    `steps_offset` are not used directly here; other methods read them back through `self.config`.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of `linear`, `scaled_linear` or `squaredcos_cap_v2`
            (and no `trained_betas` were supplied).
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # Populate sigmas / timesteps for the full training schedule.
    self.set_timesteps(num_train_timesteps, None, num_train_timesteps)


# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

    A timestep value may occur more than once in the schedule. While `self._index_counter` is still empty, the
    second occurrence is chosen (or the only one if there is just one), so that no sigma is skipped when
    denoising starts in the middle of the schedule (e.g. for image-to-image). Afterwards, the count recorded for
    this timestep in `self._index_counter` selects the occurrence.
    """
    candidates = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (candidates == timestep).nonzero()

    if len(self._index_counter) != 0:
        key = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        occurrence = self._index_counter[key]
    elif len(matches) > 1:
        occurrence = 1
    else:
        occurrence = 0

    return matches[occurrence].item()


@property
def init_noise_sigma(self):
    """Standard deviation of the initial noise distribution."""
    largest = self.sigmas.max()
    if self.config.timestep_spacing not in ("linspace", "trailing"):
        return (largest**2 + 1) ** 0.5
    return largest
def set_timesteps(
    self,
    num_inference_steps: int,
    device: Union[str, torch.device] = None,
    num_train_timesteps: Optional[int] = None,
):
    """
    Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

    Besides `self.timesteps`, this (re)builds `self.log_sigmas`, `self.sigmas`, `self.sigmas_interpol`,
    `self.sigmas_up` and `self.sigmas_down`, clears the stored first-order sample (`self.sample = None`) and
    resets `self._index_counter`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        num_train_timesteps (`int`, optional):
            number of training timesteps used to lay out the timestep grid. When falsy (e.g. `None`),
            `self.config.num_train_timesteps` is used.

    Raises:
        ValueError: if `self.config.timestep_spacing` is not one of `linspace`, `leading` or `trailing`.
    """
    self.num_inference_steps = num_inference_steps

    num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
    # All three branches produce a descending float array of timesteps.
    if self.config.timestep_spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
    elif self.config.timestep_spacing == "leading":
        step_ratio = num_train_timesteps // self.num_inference_steps
        # multiples of the integer ratio, rounded and stored as floats
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
        timesteps += self.config.steps_offset
    elif self.config.timestep_spacing == "trailing":
        step_ratio = num_train_timesteps / self.num_inference_steps
        # count down from num_train_timesteps in (possibly fractional) strides, rounded and stored as floats
        timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float)
        timesteps -= 1
    else:
        raise ValueError(
            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
        )

    # One sigma per training timestep: sqrt((1 - alphas_cumprod) / alphas_cumprod).
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    # Kept as the lookup table used by `sigma_to_t`.
    self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)

    # Interpolate the training sigmas at the inference timesteps, then append a terminal sigma of 0.
    sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
    sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
    sigmas = torch.from_numpy(sigmas).to(device=device)

    # compute up and down sigmas
    # sigmas_next[i] == sigmas[i + 1]; the wrapped-around last entry is reset to 0.
    sigmas_next = sigmas.roll(-1)
    sigmas_next[-1] = 0.0
    # NOTE(review): for the last entry both sigmas and sigmas_next are 0, so this evaluates 0 / 0. The matching
    # entry of sigmas_down is overwritten just below, but sigmas_up[-1] is not -- confirm it is never read.
    sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
    sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
    sigmas_down[-1] = 0.0

    # compute interpolated sigmas
    # Midpoint between sigma and sigma_down in log space (lerp weight 0.5); the last two entries are forced to 0.
    sigmas_interpol = sigmas.log().lerp(sigmas_down.log(), 0.5).exp()
    sigmas_interpol[-2:] = 0.0

    # set sigmas
    # Every table keeps its first entry once, duplicates each later entry (repeat_interleave(2)) and appends the
    # last entry one more time.
    self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
    self.sigmas_interpol = torch.cat(
        [sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
    )
    self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
    self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])

    if str(device).startswith("mps"):
        # mps does not support float64
        timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
    else:
        timesteps = torch.from_numpy(timesteps).to(device)

    # Timesteps for the interpolated sigmas, alternated with the regular ones:
    # [t_0, t_interpol_0, t_1, t_interpol_1, t_2, ...]
    timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
    interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()

    self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])

    # No stored sample means the next `step` call runs the first-order branch.
    self.sample = None

    # for exp beta schedules, such as the one for `pipeline_shap_e.py`
    # we need an index counter
    self._index_counter = defaultdict(int)
@property
def state_in_first_order(self):
    """`True` while no sample is stored from a first-order step, i.e. the next `step` is a first-order one."""
    return self.sample is None


def step(
    self,
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: Union[float, torch.FloatTensor],
    sample: Union[torch.FloatTensor, np.ndarray],
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Calls alternate between two phases. A first-order call (no stored sample) moves `sample` to the interpolated
    sigma and stores the incoming `sample` on `self.sample`. The following second-order call uses the stored
    sample, steps to `sigma_down`, adds `noise * sigma_up` and clears `self.sample` again.

    Args:
        model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor` or `np.ndarray`):
            current instance of sample being created by diffusion process.
        generator (`torch.Generator`, optional): random number generator forwarded to `randn_tensor`.
        return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.

    Raises:
        NotImplementedError: if `self.config.prediction_type` is `sample`.
        ValueError: if `self.config.prediction_type` is not `epsilon`, `v_prediction` or `sample`.
    """
    step_index = self.index_for_timestep(timestep)

    # advance index counter by 1
    timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
    self._index_counter[timestep_int] += 1

    if self.state_in_first_order:
        sigma = self.sigmas[step_index]
        sigma_interpol = self.sigmas_interpol[step_index]
        # sigma_up / sigma_down are looked up here but only consumed by the second-order branch below.
        sigma_up = self.sigmas_up[step_index]
        sigma_down = self.sigmas_down[step_index - 1]
    else:
        # 2nd order / KPDM2's method
        sigma = self.sigmas[step_index - 1]
        sigma_interpol = self.sigmas_interpol[step_index - 1]
        sigma_up = self.sigmas_up[step_index - 1]
        sigma_down = self.sigmas_down[step_index - 1]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # Noise is drawn on every call (so the generator advances each time) but only added in the second-order branch.
    device = model_output.device
    noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
        pred_original_sample = sample - sigma_input * model_output
    elif self.config.prediction_type == "v_prediction":
        sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )
    elif self.config.prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    if self.state_in_first_order:
        # 2. Convert to an ODE derivative for 1st order
        derivative = (sample - pred_original_sample) / sigma_hat
        # 3. delta timestep
        dt = sigma_interpol - sigma_hat

        # store for 2nd order step
        self.sample = sample
        self.dt = dt  # stored, but not read back by any code visible in this file section
        prev_sample = sample + derivative * dt
    else:
        # DPM-Solver-2
        # 2. Convert to an ODE derivative for 2nd order
        derivative = (sample - pred_original_sample) / sigma_interpol
        # 3. delta timestep
        dt = sigma_down - sigma_hat

        # Restart from the sample stored by the preceding first-order call and leave second-order state.
        sample = self.sample
        self.sample = None

        prev_sample = sample + derivative * dt
        # ancestral part: re-inject noise scaled by sigma_up
        prev_sample = prev_sample + noise * sigma_up

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)


# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Return `original_samples + noise * sigma`, where `sigma` holds one entry of `self.sigmas` per element of
    `timesteps` (looked up with `index_for_timestep`).

    Args:
        original_samples (`torch.FloatTensor`): samples to add noise to.
        noise (`torch.FloatTensor`): noise to be scaled and added.
        timesteps (`torch.FloatTensor`): timesteps selecting the sigma for each sample.

    Returns:
        `torch.FloatTensor`: the noised samples.
    """
    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
    if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
        timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

    step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    # Append trailing singleton dims until sigma has as many dims as original_samples, so it broadcasts.
    while len(sigma.shape) < len(original_samples.shape):
        sigma = sigma.unsqueeze(-1)

    noisy_samples = original_samples + noise * sigma
    return noisy_samples


def __len__(self):
    """Return the `num_train_timesteps` value stored in the scheduler config."""
    return self.config.num_train_timesteps
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Build a beta schedule by discretizing an `alpha_bar(t)` function defined on t in [0, 1], where `alpha_bar` is
    the cumulative product of (1 - beta) up to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): upper bound applied to every beta; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): float32 tensor with `num_diffusion_timesteps` entries.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """

    def _cosine_alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    def _exp_alpha_bar(t):
        return math.exp(t * -12.0)

    if alpha_transform_type == "cosine":
        alpha_bar = _cosine_alpha_bar
    elif alpha_transform_type == "exp":
        alpha_bar = _exp_alpha_bar
    else:
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    total = num_diffusion_timesteps
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    betas = [min(1 - alpha_bar((step + 1) / total) / alpha_bar(step / total), max_beta) for step in range(total)]
    return torch.tensor(betas, dtype=torch.float32)
Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.00085, # sensible defaults beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

    A timestep value may occur more than once in the schedule. While `self._index_counter` is still empty, the
    second occurrence is chosen (or the only one if there is just one), so that no sigma is skipped when
    denoising starts in the middle of the schedule (e.g. for image-to-image). Afterwards, the count recorded for
    this timestep in `self._index_counter` selects the occurrence.
    """
    candidates = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (candidates == timestep).nonzero()

    if len(self._index_counter) != 0:
        key = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
        occurrence = self._index_counter[key]
    elif len(matches) > 1:
        occurrence = 1
    else:
        occurrence = 0

    return matches[occurrence].item()


@property
def init_noise_sigma(self):
    """Standard deviation of the initial noise distribution."""
    largest = self.sigmas.max()
    if self.config.timestep_spacing not in ("linspace", "trailing"):
        return (largest**2 + 1) ** 0.5
    return largest


def scale_model_input(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
    """
    Divide `sample` by `sqrt(sigma**2 + 1)` for the sigma that belongs to `timestep`.

    The sigma is taken from `self.sigmas` while the scheduler is in its first-order state and from
    `self.sigmas_interpol` otherwise.

    Args:
        sample (`torch.FloatTensor`): input sample
        timestep (`int`, optional): current timestep

    Returns:
        `torch.FloatTensor`: scaled input sample
    """
    position = self.index_for_timestep(timestep)
    sigma_table = self.sigmas if self.state_in_first_order else self.sigmas_interpol
    return sample / ((sigma_table[position] ** 2 + 1) ** 0.5)
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() elif self.config.timestep_spacing == "leading": step_ratio = num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
def sigma_to_t(self, sigma):
    """
    Convert sigma values into (fractional) positions in `self.log_sigmas`.

    For every sigma the two neighbouring table entries are located and the position is linearly interpolated
    between their indices in log-sigma space. The lower index is capped at `len(self.log_sigmas) - 2` and the
    interpolation weight is clamped to [0, 1].

    Args:
        sigma (`torch.Tensor`): sigma values to convert.

    Returns:
        `torch.Tensor`: positions with the same shape as `sigma`.
    """
    log_sigma = sigma.log()
    table = self.log_sigmas

    # Signed log-distance of every query sigma to every table entry (table index runs along dim 0).
    gaps = log_sigma - table[:, None]

    # Running count of non-negative gaps; its argmax along dim 0 gives the lower neighbour.
    lower = (gaps >= 0).cumsum(dim=0).argmax(dim=0).clamp(max=table.shape[0] - 2)
    upper = lower + 1
    log_lo, log_hi = table[lower], table[upper]

    # Interpolation weight between the two neighbours, then the interpolated index.
    frac = ((log_lo - log_sigma) / (log_lo - log_hi)).clamp(0, 1)
    return ((1 - frac) * lower + frac * upper).view(sigma.shape)


@property
def state_in_first_order(self):
    """`True` while no sample is stored from a first-order step, i.e. the next `step` is a first-order one."""
    return self.sample is None


def step(
    self,
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: Union[float, torch.FloatTensor],
    sample: Union[torch.FloatTensor, np.ndarray],
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Calls alternate between two phases. A first-order call (no stored sample) moves `sample` to the interpolated
    sigma and stores the incoming `sample` on `self.sample`. The following second-order call restarts from the
    stored sample, steps to the next sigma and clears `self.sample` again.

    Args:
        model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor` or `np.ndarray`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.

    Raises:
        NotImplementedError: if `self.config.prediction_type` is `sample`.
        ValueError: if `self.config.prediction_type` is not `epsilon`, `v_prediction` or `sample`.
    """
    step_index = self.index_for_timestep(timestep)

    # advance index counter by 1
    counter_key = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
    self._index_counter[counter_key] += 1

    # The phase cannot change before the final branch below, so read it once.
    first_order = self.state_in_first_order

    # The 2nd order / KDPM2 phase reads every table one slot earlier than the first-order phase.
    shift = 0 if first_order else -1
    sigma = self.sigmas[step_index + shift]
    sigma_interpol = self.sigmas_interpol[step_index + shift + 1]
    sigma_next = self.sigmas[step_index + shift + 1]

    # currently only gamma=0 is supported. This usually works best anyways.
    # We can support gamma in the future but then need to scale the timestep before
    # passing it to the model which requires a change in API
    gamma = 0
    sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    prediction_type = self.config.prediction_type
    if prediction_type == "sample":
        raise NotImplementedError("prediction_type not implemented yet: sample")
    if prediction_type not in ("epsilon", "v_prediction"):
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    sigma_input = sigma_hat if first_order else sigma_interpol
    if prediction_type == "epsilon":
        pred_original_sample = sample - sigma_input * model_output
    else:
        pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
            sample / (sigma_input**2 + 1)
        )

    if first_order:
        # 2. ODE derivative and 3. delta for the first-order half step
        derivative = (sample - pred_original_sample) / sigma_hat
        dt = sigma_interpol - sigma_hat

        # store for 2nd order step
        self.sample = sample
        start = sample
    else:
        # DPM-Solver-2: 2. ODE derivative and 3. delta for the second-order step
        derivative = (sample - pred_original_sample) / sigma_interpol
        dt = sigma_next - sigma_hat

        # Restart from the sample stored by the preceding first-order call and leave second-order state.
        start = self.sample
        self.sample = None

    prev_sample = start + derivative * dt

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)


# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Return `original_samples + noise * sigma`, where `sigma` holds one entry of `self.sigmas` per element of
    `timesteps` (looked up with `index_for_timestep`).

    Args:
        original_samples (`torch.FloatTensor`): samples to add noise to.
        noise (`torch.FloatTensor`): noise to be scaled and added.
        timesteps (`torch.FloatTensor`): timesteps selecting the sigma for each sample.

    Returns:
        `torch.FloatTensor`: the noised samples.
    """
    target_device = original_samples.device

    # Make sure sigmas and timesteps have the same device and dtype as original_samples
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    # mps does not support float64, so floating point timesteps are down-cast there
    cast = {}
    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        cast["dtype"] = torch.float32
    schedule_timesteps = self.timesteps.to(target_device, **cast)
    timesteps = timesteps.to(target_device, **cast)

    step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    sigma = sigmas[step_indices].flatten()

    # Append trailing singleton dims until sigma has as many dims as original_samples, so it broadcasts.
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    return original_samples + noise * sigma


def __len__(self):
    """Return the `num_train_timesteps` value stored in the scheduler config."""
    return self.config.num_train_timesteps
# See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import SchedulerMixin @dataclass class KarrasVeOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. """ prev_sample: torch.FloatTensor derivative: torch.FloatTensor pred_original_sample: Optional[torch.FloatTensor] = None class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
@register_to_config
def __init__(
    self,
    sigma_min: float = 0.02,
    sigma_max: float = 100,
    s_noise: float = 1.007,
    s_churn: float = 80,
    s_min: float = 0.05,
    s_max: float = 50,
):
    """
    Initialise the scheduler with empty inference state.

    Only `sigma_max` is read directly here (as the initial noise standard deviation); the remaining
    arguments are consumed elsewhere through `self.config` (presumably stored by `register_to_config` —
    confirm against `ConfigMixin`).

    Args:
        sigma_min (`float`): minimum noise magnitude.
        sigma_max (`float`): maximum noise magnitude.
        s_noise (`float`): scale of the extra noise added in `add_noise_to_input`.
        s_churn (`float`): controls the overall amount of stochasticity.
        s_min (`float`): lower end of the sigma range in which noise is added.
        s_max (`float`): upper end of the sigma range in which noise is added.
    """
    # standard deviation of the initial noise distribution
    self.init_noise_sigma = sigma_max

    # setable values; all stay `None` until `set_timesteps` is called
    self.num_inference_steps: Optional[int] = None
    # `set_timesteps` stores a torch tensor here (the previous `np.IntTensor` annotation named a
    # type that numpy does not provide)
    self.timesteps: Optional[torch.Tensor] = None
    self.schedule: Optional[torch.FloatTensor] = None  # sigma(t_i)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
    """
    Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

    `self.timesteps` becomes the descending integer sequence `num_inference_steps - 1, ..., 0`, and
    `self.schedule[k]` holds
    `sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (timesteps[k] / (num_inference_steps - 1))`.

    Args:
        num_inference_steps (`int`):
            the number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, optional):
            device on which `self.timesteps` and `self.schedule` are created.
    """
    self.num_inference_steps = num_inference_steps
    timesteps = np.arange(0, num_inference_steps)[::-1].copy()
    self.timesteps = torch.from_numpy(timesteps).to(device)

    sigma_max_sq = self.config.sigma_max**2
    sigma_min_sq = self.config.sigma_min**2
    # With a single step the interpolation span is zero; dividing by it would turn the only
    # schedule entry into NaN (0 / 0). Clamping the span keeps that entry at sigma_max**2.
    span = max(num_inference_steps - 1, 1)
    schedule = sigma_max_sq * (sigma_min_sq / sigma_max_sq) ** (timesteps / span)
    self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
def step_correct(
    self,
    model_output: torch.FloatTensor,
    sigma_hat: float,
    sigma_prev: float,
    sample_hat: torch.FloatTensor,
    sample_prev: torch.FloatTensor,
    derivative: torch.FloatTensor,
    return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
    """
    Correct the predicted sample based on the output `model_output` of the network.

    A second derivative is evaluated at `sample_prev` / `sigma_prev` and averaged with the supplied
    `derivative`; the step from `sample_hat` is then retaken with that average.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        sigma_hat (`float`): noise level associated with `sample_hat`.
        sigma_prev (`float`): noise level associated with `sample_prev`; used as a divisor, so it must be
            non-zero.
        sample_hat (`torch.FloatTensor`): sample the step starts from.
        sample_prev (`torch.FloatTensor`): previously predicted sample that is being corrected.
        derivative (`torch.FloatTensor`): derivative obtained from the uncorrected step.
        return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class

    Returns:
        [`KarrasVeOutput`] if `return_dict` is True, otherwise the tuple `(sample_prev, derivative)`, where
        `derivative` is the value passed in, not the corrected one.
    """
    denoised = sample_prev + sigma_prev * model_output
    corrected_derivative = (sample_prev - denoised) / sigma_prev

    # equal-weight average of the incoming derivative and the one evaluated at the predicted sample
    averaged_derivative = 0.5 * derivative + 0.5 * corrected_derivative
    sample_prev = sample_hat + (sigma_prev - sigma_hat) * averaged_derivative

    if return_dict:
        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=denoised
        )
    return (sample_prev, derivative)
Args: prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. """ prev_sample: jnp.ndarray derivative: jnp.ndarray state: KarrasVeSchedulerState class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and the VE column of Table 1 from [1] for reference. [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. Args: sigma_min (`float`): minimum noise magnitude sigma_max (`float`): maximum noise magnitude s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000, 1.011]. 
def add_noise_to_input(
    self,
    state: KarrasVeSchedulerState,
    sample: jnp.ndarray,
    sigma: float,
    key: jnp.ndarray,
) -> Tuple[jnp.ndarray, float]:
    """
    Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
    higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.

    Args:
        state (`KarrasVeSchedulerState`): scheduler state; only `state.num_inference_steps` is read.
        sample (`jnp.ndarray`): current sample; the drawn noise has the same shape.
        sigma (`float`): current noise level. Noise is only added when
            `config.s_min <= sigma <= config.s_max`; otherwise gamma is 0 and the sample is returned
            unchanged apart from the (zero-weighted) noise term.
        key (`jnp.ndarray`): a single JAX PRNG key. (The annotation was `random.KeyArray`, which
            is not available in recent JAX releases.)

    Returns:
        `Tuple[jnp.ndarray, float]`: `(sample_hat, sigma_hat)`.
    """
    if self.config.s_min <= sigma <= self.config.s_max:
        gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
    else:
        gamma = 0

    # sample eps ~ N(0, S_noise^2 * I)
    # `random.split(key, num=1)` returns a *batch* holding one key; `random.normal` needs a single
    # key, so take the first (only) entry instead of passing the batch through.
    subkey = random.split(key, num=1)[0]
    eps = self.config.s_noise * random.normal(key=subkey, shape=sample.shape)
    sigma_hat = sigma + gamma * sigma
    sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

    return sample_hat, sigma_hat
def step_correct(
    self,
    state: KarrasVeSchedulerState,
    model_output: jnp.ndarray,
    sigma_hat: float,
    sigma_prev: float,
    sample_hat: jnp.ndarray,
    sample_prev: jnp.ndarray,
    derivative: jnp.ndarray,
    return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
    """
    Correct the predicted sample based on the output `model_output` of the network.

    A second derivative is evaluated at `sample_prev` / `sigma_prev`, averaged with the supplied
    `derivative`, and the step from `sample_hat` is retaken with that average.

    Args:
        state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class; passed through
            unchanged.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        sigma_hat (`float`): noise level associated with `sample_hat`.
        sigma_prev (`float`): noise level associated with `sample_prev`; used as a divisor.
        sample_hat (`jnp.ndarray`): sample the step starts from.
        sample_prev (`jnp.ndarray`): previously predicted sample that is being corrected.
        derivative (`jnp.ndarray`): derivative obtained from the uncorrected step.
        return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

    Returns:
        [`FlaxKarrasVeOutput`] if `return_dict` is True, otherwise the tuple
        `(sample_prev, derivative, state)`; `derivative` is the value passed in.
    """
    denoised = sample_prev + sigma_prev * model_output
    corrected_derivative = (sample_prev - denoised) / sigma_prev

    # equal-weight average of the incoming derivative and the one evaluated at the predicted sample
    averaged_derivative = 0.5 * derivative + 0.5 * corrected_derivative
    sample_prev = sample_hat + (sigma_prev - sigma_hat) * averaged_derivative

    if return_dict:
        return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
    return (sample_prev, derivative, state)
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import warnings from dataclasses import dataclass from typing import List, Optional, Tuple, Union import numpy as np import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete class LMSDiscreteSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): The predicted denoised sample (x_{0}) based on the model output from the current timestep. `pred_original_sample` can be used to preview progress or for guidance. 
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with `num_diffusion_timesteps` entries, the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # the message now spells the parameter name the way the signature does
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    use_karras_sigmas: Optional[bool] = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
):
    """
    Build the beta/alpha tables and initialise the sampling state.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): first beta of the schedule.
        beta_end (`float`): last beta of the schedule.
        beta_schedule (`str`): one of `linear`, `scaled_linear` or `squaredcos_cap_v2`; ignored when
            `trained_betas` is given.
        trained_betas (`np.ndarray` or `List[float]`, optional): betas to use directly.
        use_karras_sigmas (`bool`, optional): stored on the instance and consulted by `set_timesteps`.
        prediction_type (`str`), timestep_spacing (`str`), steps_offset (`int`): not read in this method;
            other methods access them through `self.config`.

    Raises:
        NotImplementedError: if `beta_schedule` is not one of the supported names.
    """
    if trained_betas is not None:
        self.betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        self.betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        )
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        self.betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # sigma = sqrt((1 - alpha_bar) / alpha_bar), highest noise first, with a trailing 0.0
    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
    sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
    self.sigmas = torch.from_numpy(sigmas)

    # setable values
    self.num_inference_steps = None
    # must be assigned before `set_timesteps`, which reads it
    self.use_karras_sigmas = use_karras_sigmas
    self.set_timesteps(num_train_timesteps, None)
    self.derivatives = []
    self.is_scale_input_called = False
def get_lms_coefficient(self, order, t, current_order):
    """
    Compute a linear multistep coefficient.

    The integrand is the product, over every `k` in `range(order)` except `current_order`, of
    `(tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])`. It is integrated numerically
    from `sigmas[t]` to `sigmas[t + 1]`.

    Args:
        order: number of terms in the multistep formula.
        t: index into `self.sigmas` of the current step.
        current_order: which of the `order` terms this coefficient belongs to.

    Returns:
        The value of the integral (first element of the `integrate.quad` result).
    """
    sigmas = self.sigmas
    other_orders = [k for k in range(order) if k != current_order]

    def basis_polynomial(tau):
        value = 1.0
        for k in other_orders:
            numerator = tau - sigmas[t - k]
            denominator = sigmas[t - current_order] - sigmas[t - k]
            value = value * (numerator / denominator)
        return value

    quad_result = integrate.quad(basis_polynomial, sigmas[t], sigmas[t + 1], epsrel=1e-4)
    return quad_result[0]
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[ ::-1 ].copy() elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(float) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) if str(device).startswith("mps"): # mps does not support float64 self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) else: self.timesteps = torch.from_numpy(timesteps).to(device=device) self.derivatives = [] # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma log_sigma = np.log(sigma) # get distribution dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 low = log_sigmas[low_idx] high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx t = t.reshape(sigma.shape) return t # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: """Constructs the noise schedule of Karras et al. 
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    order: int = 4,
    return_dict: bool = True,
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`float` or `torch.FloatTensor`): current timestep in the diffusion chain; it must match
            exactly one entry of `self.timesteps`.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        order (`int`): maximum number of stored derivatives combined in the multistep update. The order
            actually used is `min(step_index + 1, order)`.
        return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class

    Returns:
        [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] or `tuple`:
        [`~schedulers.scheduling_lms_discrete.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
        a `tuple`. When returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if `config.prediction_type` is not `epsilon`, `sample` or `v_prediction`.
    """
    if not self.is_scale_input_called:
        warnings.warn(
            "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
            "See `StableDiffusionPipeline` for a usage example."
        )

    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    step_index = (self.timesteps == timestep).nonzero().item()
    sigma = self.sigmas[step_index]

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif self.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    else:
        # list every prediction type handled above, including `sample`
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
            " or `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma
    self.derivatives.append(derivative)
    # keep at most `order` derivatives, dropping the oldest
    if len(self.derivatives) > order:
        self.derivatives.pop(0)

    # 3. Compute linear multistep coefficients
    order = min(step_index + 1, order)
    lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

    # 4. Compute previous sample based on the derivatives path (newest derivative pairs with coefficient 0)
    prev_sample = sample + sum(
        coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
    )

    if not return_dict:
        return (prev_sample,)

    return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
timesteps.to(original_samples.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_lms_discrete_flax.py ================================================ # Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
    """
    State passed to and returned from the methods of `FlaxLMSDiscreteScheduler`.

    `create` fills `common` and the required setable values; `num_inference_steps` and `derivatives` stay `None`
    until `FlaxLMSDiscreteScheduler.set_timesteps` replaces them.
    """

    # shared scheduler state; `alphas_cumprod` is read from it to derive the sigmas
    common: CommonSchedulerState

    # setable values
    # standard deviation of the initial noise distribution (`sigmas.max()` in `create_state`)
    init_noise_sigma: jnp.ndarray
    # timesteps of the diffusion chain; replaced by `set_timesteps`
    timesteps: jnp.ndarray
    # noise levels derived from `alphas_cumprod`; replaced by `set_timesteps`
    sigmas: jnp.ndarray
    # `step` raises a `ValueError` while this is still `None`
    num_inference_steps: Optional[int] = None

    # running values
    # derivative history that `step` appends to; initialised as an empty array by `set_timesteps`
    derivatives: Optional[jnp.ndarray] = None

    @classmethod
    def create(
        cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
    ):
        # Only the fields without defaults are passed; the optional fields keep their `None` defaults.
        return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState:
    """
    Build the initial scheduler state that spans every training timestep.

    Args:
        common (`CommonSchedulerState`, *optional*):
            shared scheduler state; created from this scheduler when not given.

    Returns:
        `LMSDiscreteSchedulerState`: state holding the descending training timesteps, the sigmas computed from
        `common.alphas_cumprod`, and the largest sigma as `init_noise_sigma`.
    """
    if common is None:
        common = CommonSchedulerState.create(self)

    # training timesteps in descending order
    timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

    alphas_cumprod = common.alphas_cumprod
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

    return LMSDiscreteSchedulerState.create(
        common=common,
        # standard deviation of the initial noise distribution
        init_noise_sigma=sigmas.max(),
        timesteps=timesteps,
        sigmas=sigmas,
    )
def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
    """
    Compute a linear multistep coefficient.

    The coefficient is the integral, from `state.sigmas[t]` to `state.sigmas[t + 1]`, of the product over all
    `k in range(order)` with `k != current_order` of
    `(tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])`.

    Args:
        state (`LMSDiscreteSchedulerState`): state whose `sigmas` are indexed.
        order (`int`): number of sigma nodes `sigmas[t], sigmas[t - 1], ...` that enter the product.
        t (`int`): index into `state.sigmas` of the current step.
        current_order (`int`): offset (counted back from `t`) of the node this coefficient belongs to.

    Returns:
        `float`: the integrated coefficient.
    """
    # Read every node once as a Python float: `integrate.quad` evaluates the integrand repeatedly, and indexing
    # `state.sigmas` inside it would repeat the same lookups on each evaluation.
    anchor = float(state.sigmas[t - current_order])
    other_nodes = [float(state.sigmas[t - k]) for k in range(order) if k != current_order]
    lower = float(state.sigmas[t])
    upper = float(state.sigmas[t + 1])

    def lms_derivative(tau):
        prod = 1.0
        for node in other_nodes:
            prod *= (tau - node) / (anchor - node)
        return prod

    integrated_coeff = integrate.quad(lms_derivative, lower, upper, epsrel=1e-4)[0]

    return integrated_coeff
def step(
    self,
    state: LMSDiscreteSchedulerState,
    model_output: jnp.ndarray,
    timestep: int,
    sample: jnp.ndarray,
    order: int = 4,
    return_dict: bool = True,
) -> Union[FlaxLMSSchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
    process from the learned model outputs (most often the predicted noise).

    The position of `timestep` inside `state.timesteps` is converted to a Python `int`, so this method has to be
    called with concrete values.

    Args:
        state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
        model_output (`jnp.ndarray`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain (a value of `state.timesteps`).
        sample (`jnp.ndarray`):
            current instance of sample being created by diffusion process.
        order: maximum number of past derivatives combined in the multistep update.
        return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

    Returns:
        [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
        `tuple` `(prev_sample, state)`.

    Raises:
        `ValueError`: if `set_timesteps` has not been run or `prediction_type` is not supported.
    """
    if state.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # `state.sigmas` is ordered like `state.timesteps`, so the timestep value has to be mapped to its position
    # first (same lookup as in `scale_model_input`); indexing with the raw timestep reads the wrong sigma.
    (step_index,) = jnp.where(state.timesteps == timestep, size=1)
    step_index = int(step_index[0])
    sigma = state.sigmas[step_index]

    # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
    if self.config.prediction_type == "epsilon":
        pred_original_sample = sample - sigma * model_output
    elif self.config.prediction_type == "v_prediction":
        # * c_out + input * c_skip
        pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
            " `v_prediction`"
        )

    # 2. Convert to an ODE derivative
    derivative = (sample - pred_original_sample) / sigma

    # Keep the history stacked along a new leading axis, one entry per step. `jnp.append` without an axis would
    # flatten everything into one vector and lose the per-step structure.
    history = state.derivatives
    if history is None or history.size == 0:
        history = derivative[None]
    else:
        history = jnp.concatenate([history, derivative[None]], axis=0)
    if history.shape[0] > order:
        history = history[1:]
    state = state.replace(derivatives=history)

    # 3. Compute linear multistep coefficients
    order = min(step_index + 1, order)
    lms_coeffs = [self.get_lms_coefficient(state, order, step_index, curr_order) for curr_order in range(order)]

    # 4. Compute previous sample based on the derivatives path (newest derivative first)
    prev_sample = sample + sum(coeff * past for coeff, past in zip(lms_coeffs, history[::-1]))

    if not return_dict:
        return (prev_sample, state)

    return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Discretize an alpha_bar function, i.e. the cumulative product of (1-beta) over time t in [0, 1], into a beta
    schedule.

    For every step `i` the beta is `1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor with the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type not in ("cosine", "exp"):
        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")

    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    else:

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    steps = num_diffusion_timesteps
    betas = [min(1 - alpha_bar_fn((i + 1) / steps) / alpha_bar_fn(i / steps), max_beta) for i in range(steps)]
    return torch.tensor(betas, dtype=torch.float32)
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    beta_schedule: str = "linear",
    trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
    skip_prk_steps: bool = False,
    set_alpha_to_one: bool = False,
    prediction_type: str = "epsilon",
    timestep_spacing: str = "leading",
    steps_offset: int = 0,
):
    """
    Build the beta/alpha tables and reset all running and setable values.

    `trained_betas` takes precedence over `beta_schedule`; an unknown `beta_schedule` raises
    `NotImplementedError`.
    """
    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    elif beta_schedule == "linear":
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
    elif beta_schedule == "scaled_linear":
        # this schedule is very specific to the latent diffusion model.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32)
        betas = sqrt_betas**2
    elif beta_schedule == "squaredcos_cap_v2":
        # Glide cosine schedule
        betas = betas_for_alpha_bar(num_train_timesteps)
    else:
        raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

    self.betas = betas
    self.alphas = 1.0 - self.betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    # alpha product used as the "previous" one when the previous timestep is negative
    if set_alpha_to_one:
        self.final_alpha_cumprod = torch.tensor(1.0)
    else:
        self.final_alpha_cumprod = self.alphas_cumprod[0]

    # standard deviation of the initial noise distribution
    self.init_noise_sigma = 1.0

    # For now we only support F-PNDM, i.e. the runge-kutta method
    # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
    # mainly at formula (9), (12), (13) and the Algorithm 2.
    self.pndm_order = 4

    # running values
    self.cur_model_output = 0
    self.counter = 0
    self.cur_sample = None
    self.ets = []

    # setable values
    self.num_inference_steps = None
    self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
    self.prk_timesteps = None
    self.plms_timesteps = None
    self.timesteps = None
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": self._timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() self._timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 self._timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio))[::-1].astype( np.int64 ) self._timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. 
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample at the previous timestep by dispatching to `step_prk()` or `step_plms()`.

    `step_prk()` is used while `counter` is still below the number of prk timesteps and
    `config.skip_prk_steps` is not set; every other call goes to `step_plms()`.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: whatever the selected step function
        returns for the given `return_dict`.
    """
    in_prk_phase = self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps
    step_fn = self.step_prk if in_prk_phase else self.step_plms
    return step_fn(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
def step_plms(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Step function propagating the sample with the linear multi-step method. The stored model outputs (`ets`) are
    combined with fixed weights that depend on how many of them are available.

    Args:
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`int`): current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

    Returns:
        [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
        True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

    Raises:
        `ValueError`: if `set_timesteps` was not run, or if prk steps are not skipped and fewer than three model
        outputs have been stored.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if not self.config.skip_prk_steps and len(self.ets) < 3:
        raise ValueError(
            f"{self.__class__} can only be run AFTER scheduler has been run "
            "in 'prk' mode for at least 12 iterations "
            "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
            "for more information."
        )

    stride = self.config.num_train_timesteps // self.num_inference_steps

    if self.counter == 1:
        # The second call does not store its model output: it steps from `timestep + stride` to `timestep`
        # again, this time with the averaged output computed below.
        prev_timestep, timestep = timestep, timestep + stride
    else:
        prev_timestep = timestep - stride
        # keep at most four model outputs: the last three plus the new one
        self.ets = self.ets[-3:]
        self.ets.append(model_output)

    history_len = len(self.ets)
    if history_len == 1 and self.counter == 0:
        # first call: use the model output as is and remember the sample for the second call
        self.cur_sample = sample
    elif history_len == 1 and self.counter == 1:
        model_output = (model_output + self.ets[-1]) / 2
        sample = self.cur_sample
        self.cur_sample = None
    elif history_len == 2:
        model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
    elif history_len == 3:
        model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
    else:
        model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

    prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
    self.counter += 1

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
    # Computes x_(t−δ) with formula (9) of the PNDM paper https://arxiv.org/pdf/2202.09778.pdf
    # (x_t is added to both sides of the equation).
    #
    # Notation used below:
    #   alpha_t     -> α_t, cumulative alpha product at `timestep`
    #   alpha_prev  -> α_(t−δ), cumulative alpha product at `prev_timestep`
    #   sample      -> x_t
    #   noise_pred  -> e_θ(x_t, t)
    #   return      -> x_(t−δ)
    alpha_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prev = self.alphas_cumprod[prev_timestep]
    else:
        # there is no table entry before timestep 0
        alpha_prev = self.final_alpha_cumprod
    beta_t = 1 - alpha_t
    beta_prev = 1 - alpha_prev

    prediction_type = self.config.prediction_type
    if prediction_type == "v_prediction":
        # turn the v-prediction into the equivalent noise prediction
        noise_pred = (alpha_t**0.5) * model_output + (beta_t**0.5) * sample
    elif prediction_type == "epsilon":
        noise_pred = model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `v_prediction`"
        )

    # coefficient of x_t: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1
    # which simplifies to sqrt(α_(t−δ)) / sqrt(α_t)
    sample_coeff = (alpha_prev / alpha_t) ** (0.5)

    # denominator of e_θ(x_t, t) in formula (9)
    noise_denom = alpha_t * beta_prev ** (0.5) + (alpha_t * beta_t * alpha_prev) ** (0.5)

    # full formula (9)
    return sample_coeff * sample - (alpha_prev - alpha_t) * noise_pred / noise_denom
def __len__(self):
    """Return `config.num_train_timesteps`, the number of timesteps the scheduler was configured with."""
    num_train_timesteps = self.config.num_train_timesteps
    return num_train_timesteps
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass from typing import Optional, Tuple, Union import flax import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, ) @flax.struct.dataclass class PNDMSchedulerState: common: CommonSchedulerState final_alpha_cumprod: jnp.ndarray # setable values init_noise_sigma: jnp.ndarray timesteps: jnp.ndarray num_inference_steps: Optional[int] = None prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None # running values cur_model_output: Optional[jnp.ndarray] = None counter: Optional[jnp.int32] = None cur_sample: Optional[jnp.ndarray] = None ets: Optional[jnp.ndarray] = None @classmethod def create( cls, common: CommonSchedulerState, final_alpha_cumprod: jnp.ndarray, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, ): return cls( common=common, final_alpha_cumprod=final_alpha_cumprod, init_noise_sigma=init_noise_sigma, timesteps=timesteps, ) @dataclass class FlaxPNDMSchedulerOutput(FlaxSchedulerOutput): state: PNDMSchedulerState class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. 
For more details, see the original paper: https://arxiv.org/abs/2202.09778 Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`jnp.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. set_alpha_to_one (`bool`, default `False`): each diffusion step uses the value of alphas product at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the value of alpha at step 0. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): the `dtype` used for params and computation. 
def create_state(self, common: Optional[CommonSchedulerState] = None) -> PNDMSchedulerState:
    """
    Build the initial scheduler state that spans every training timestep.

    Args:
        common (`CommonSchedulerState`, *optional*):
            shared scheduler state; created from this scheduler when not given.

    Returns:
        `PNDMSchedulerState`: state with the final alpha product, an `init_noise_sigma` of 1.0 and the descending
        training timesteps.
    """
    if common is None:
        common = CommonSchedulerState.create(self)

    # At every step in ddim, we are looking into the previous alphas_cumprod
    # For the final step, there is no previous alphas_cumprod because we are already at 0
    # `set_alpha_to_one` decides whether we set this parameter simply to one or
    # whether we use the final alpha of the "non-previous" one.
    if self.config.set_alpha_to_one:
        final_alpha_cumprod = jnp.array(1.0, dtype=self.dtype)
    else:
        final_alpha_cumprod = common.alphas_cumprod[0]

    descending_timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]

    return PNDMSchedulerState.create(
        common=common,
        final_alpha_cumprod=final_alpha_cumprod,
        # standard deviation of the initial noise distribution
        init_noise_sigma=jnp.array(1.0, dtype=self.dtype),
        timesteps=descending_timesteps,
    )
num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. shape (`Tuple`): the shape of the samples to be generated. """ step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + self.config.steps_offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 prk_timesteps = jnp.array([], dtype=jnp.int32) plms_timesteps = jnp.concatenate([_timesteps[:-1], _timesteps[-2:-1], _timesteps[-1:]])[::-1] else: prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), self.pndm_order, ) prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1] plms_timesteps = _timesteps[:-3][::-1] timesteps = jnp.concatenate([prk_timesteps, plms_timesteps]) # initial running values cur_model_output = jnp.zeros(shape, dtype=self.dtype) counter = jnp.int32(0) cur_sample = jnp.zeros(shape, dtype=self.dtype) ets = jnp.zeros((4,) + shape, dtype=self.dtype) return state.replace( timesteps=timesteps, num_inference_steps=num_inference_steps, prk_timesteps=prk_timesteps, plms_timesteps=plms_timesteps, cur_model_output=cur_model_output, counter=counter, cur_sample=cur_sample, ets=ets, ) def scale_model_input( self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. sample (`jnp.ndarray`): input sample timestep (`int`, optional): current timestep Returns: `jnp.ndarray`: scaled input sample """ return sample def step( self, state: PNDMSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, return_dict: bool = True, ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if self.config.skip_prk_steps: prev_sample, state = self.step_plms(state, model_output, timestep, sample) else: prk_prev_sample, prk_state = self.step_prk(state, model_output, timestep, sample) plms_prev_sample, plms_state = self.step_plms(state, model_output, timestep, sample) cond = state.counter < len(state.prk_timesteps) prev_sample = jax.lax.select(cond, prk_prev_sample, plms_prev_sample) state = state.replace( cur_model_output=jax.lax.select(cond, prk_state.cur_model_output, plms_state.cur_model_output), ets=jax.lax.select(cond, prk_state.ets, plms_state.ets), cur_sample=jax.lax.select(cond, prk_state.cur_sample, plms_state.cur_sample), counter=jax.lax.select(cond, prk_state.counter, plms_state.counter), ) if not return_dict: return (prev_sample, state) return FlaxPNDMSchedulerOutput(prev_sample=prev_sample, state=state) def step_prk( self, state: PNDMSchedulerState, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, ) -> Union[FlaxPNDMSchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. Args: state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. model_output (`jnp.ndarray`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than FlaxPNDMSchedulerOutput class Returns: [`FlaxPNDMSchedulerOutput`] or `tuple`: [`FlaxPNDMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
    def step_plms(
        self,
        state: PNDMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
    ) -> Union[FlaxPNDMSchedulerOutput, Tuple]:
        """
        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
        times to approximate the solution.

        Args:
            state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.

        Returns:
            `tuple`: `(prev_sample, state)`, where `state` has its `ets`, `cur_sample`, `cur_model_output` and
            `counter` fields updated. This method always returns a tuple; it has no `return_dict` option.

        Raises:
            ValueError: if `state.num_inference_steps` is `None`.
        """

        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # NOTE: There is no way to check in the jitted runtime if the prk mode was ran before

        # one inference step back, floored at 0
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
        prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)

        # Reference:
        # if state.counter != 1:
        #     state.ets.append(model_output)
        # else:
        #     prev_timestep = timestep
        #     timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps

        # counter == 1 is the special case: the step goes from (timestep + one inference step) to timestep
        prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
        timestep = jnp.where(
            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
        )

        # Reference:
        # if len(state.ets) == 1 and state.counter == 0:
        #     model_output = model_output
        #     state.cur_sample = sample
        # elif len(state.ets) == 1 and state.counter == 1:
        #     model_output = (model_output + state.ets[-1]) / 2
        #     sample = state.cur_sample
        #     state.cur_sample = None
        # elif len(state.ets) == 2:
        #     model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
        # elif len(state.ets) == 3:
        #     model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
        # else:
        #     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])

        # `ets` is a fixed-size buffer: entries 1..3 are shifted down to 0..2 and the new model output is
        # written to slot 3. When counter == 1 the buffer and the stored sample are left untouched.
        state = state.replace(
            ets=jax.lax.select(
                state.counter != 1,
                state.ets.at[0:3].set(state.ets[1:4]).at[3].set(model_output),  # counter != 1
                state.ets,  # counter 1
            ),
            cur_sample=jax.lax.select(
                state.counter != 1,
                sample,  # counter != 1
                state.cur_sample,  # counter 1
            ),
        )

        # The counter (clipped to 0..4) picks which combination of stored outputs is used;
        # all five candidates are computed and `select_n` chooses one of them.
        state = state.replace(
            cur_model_output=jax.lax.select_n(
                jnp.clip(state.counter, 0, 4),
                model_output,  # counter 0
                (model_output + state.ets[-1]) / 2,  # counter 1
                (3 * state.ets[-1] - state.ets[-2]) / 2,  # counter 2
                (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12,  # counter 3
                (1 / 24)
                * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]),  # counter >= 4
            ),
        )

        # the update is applied to the sample and model output stored in the state, not to the raw arguments
        sample = state.cur_sample
        model_output = state.cur_model_output
        prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
        state = state.replace(counter=state.counter + 1)

        return (prev_sample, state)
(0.5) + ( alpha_prod_t * beta_prod_t * alpha_prod_t_prev ) ** (0.5) # full formula (9) prev_sample = ( sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff ) return prev_sample def add_noise( self, state: PNDMSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: return add_noise_common(state.common, original_samples, noise, timesteps) def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_repaint.py ================================================ # Copyright 2023 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import SchedulerMixin @dataclass class RePaintSchedulerOutput(BaseOutput): """ Output class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. 
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` holding the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        # The message names the parameter exactly as it appears in the signature.
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at `max_beta`
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. eta (`float`): The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and 1.0 is DDPM scheduler respectively. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. 
""" order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", eta: float = 0.0, trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) elif beta_schedule == "sigmoid": # GeoDiff sigmoid schedule betas = torch.linspace(-6, 6, num_train_timesteps) self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) self.final_alpha_cumprod = torch.tensor(1.0) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.eta = eta def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample def set_timesteps( self, num_inference_steps: int, jump_length: int = 10, jump_n_sample: int = 10, device: Union[str, torch.device] = None, ): num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) self.num_inference_steps = num_inference_steps timesteps = [] jumps = {} for j in range(0, num_inference_steps - jump_length, jump_length): jumps[j] = jump_n_sample - 1 t = num_inference_steps while t >= 1: t = t - 1 timesteps.append(t) if jumps.get(t, 0) > 0: jumps[t] = jumps[t] - 1 for _ in range(jump_length): t = t + 1 timesteps.append(t) timesteps = np.array(timesteps) * (self.config.num_train_timesteps // self.num_inference_steps) self.timesteps = torch.from_numpy(timesteps).to(device) def _get_variance(self, t): prev_timestep = t - self.config.num_train_timesteps // self.num_inference_steps alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # For t > 0, compute predicted variance βt (see formula (6) and (7) from # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get # previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add # variance to pred_sample # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf # without eta. 
# variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) return variance def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, original_image: torch.FloatTensor, mask: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[RePaintSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. original_image (`torch.FloatTensor`): the original image to inpaint on. mask (`torch.FloatTensor`): the mask where 0.0 values define which part of the original image to inpaint (change). generator (`torch.Generator`, *optional*): random number generator. return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class Returns: [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.RePaintSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ t = timestep prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 1. compute alphas, betas alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t # 2. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 # 3. Clip "predicted x_0" if self.config.clip_sample: pred_original_sample = torch.clamp(pred_original_sample, -1, 1) # We choose to follow RePaint Algorithm 1 to get x_{t-1}, however we # substitute formula (7) in the algorithm coming from DDPM paper # (formula (4) Algorithm 2 - Sampling) with formula (12) from DDIM paper. # DDIM schedule gives the same results as DDPM with eta = 1.0 # Noise is being reused in 7. and 8., but no impact on quality has # been observed. # 5. Add noise device = model_output.device noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype) std_dev_t = self.eta * self._get_variance(timestep) ** 0.5 variance = 0 if t > 0 and self.eta > 0: variance = std_dev_t * noise # 6. compute "direction pointing to x_t" of formula (12) # from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise # 9. 
Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part if not return_dict: return ( pred_prev_sample, pred_original_sample, ) return RePaintSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) def undo_step(self, sample, timestep, generator=None): n = self.config.num_train_timesteps // self.num_inference_steps for i in range(n): beta = self.betas[timestep + i] if sample.device.type == "mps": # randn does not work reproducibly on mps noise = randn_tensor(sample.shape, dtype=sample.dtype, generator=generator) noise = noise.to(sample.device) else: noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise return sample def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.") def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_sde_ve.py ================================================ # Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch import math from dataclasses import dataclass from typing import Optional, Tuple, Union import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, randn_tensor from .scheduling_utils import SchedulerMixin, SchedulerOutput @dataclass class SdeVeOutput(BaseOutput): """ Output class for the ScoreSdeVeScheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. """ prev_sample: torch.FloatTensor prev_sample_mean: torch.FloatTensor class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): """ The variance exploding stochastic differential equation (SDE) scheduler. For more information, see the original paper: https://arxiv.org/abs/2011.13456 [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. snr (`float`): coefficient weighting the step from the model_output sample (from the network) to the random noise. sigma_min (`float`): initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the distribution of the data. 
sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to epsilon. correct_steps (`int`): number of correction steps performed on a produced sample. """ order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 2000, snr: float = 0.15, sigma_min: float = 0.01, sigma_max: float = 1348.0, sampling_eps: float = 1e-5, correct_steps: int = 1, ): # standard deviation of the initial noise distribution self.init_noise_sigma = sigma_max # setable values self.timesteps = None self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample def set_timesteps( self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None ): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). """ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device) def set_sigmas( self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None ): """ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. 
The sigmas control the weight of the `drift` and `diffusion` components of sample update. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. sigma_min (`float`, optional): initial noise scale value (overrides value given at Scheduler instantiation). sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). """ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps if self.timesteps is None: self.set_timesteps(num_inference_steps, sampling_eps) self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps) self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps)) self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) def get_adjacent_sigma(self, timesteps, t): return torch.where( timesteps == 0, torch.zeros_like(t.to(timesteps.device)), self.discrete_sigmas[timesteps - 1].to(timesteps.device), ) def step_pred( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SdeVeOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. 
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.timesteps is None: raise ValueError( "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" ) timestep = timestep * torch.ones( sample.shape[0], device=sample.device ) # torch.repeat_interleave(timestep, sample.shape[0]) timesteps = (timestep * (len(self.timesteps) - 1)).long() # mps requires indices to be in the same device, so we use cpu as is the default with cuda timesteps = timesteps.to(self.discrete_sigmas.device) sigma = self.discrete_sigmas[timesteps].to(sample.device) adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) drift = torch.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods diffusion = diffusion.flatten() while len(diffusion.shape) < len(sample.shape): diffusion = diffusion.unsqueeze(-1) drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of noise = randn_tensor( sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype ) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO is the variable diffusion the correct scaling term for the noise? 
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g if not return_dict: return (prev_sample, prev_sample_mean) return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) def step_correct( self, model_output: torch.FloatTensor, sample: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. generator: random number generator. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.timesteps is None: raise ValueError( "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" ) # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. 
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Perturb `original_samples` with noise scaled by the discrete sigma selected by `timesteps`.

    Args:
        original_samples (`torch.FloatTensor`): clean samples, batch on the first axis, any rank.
        noise (`torch.FloatTensor` or `None`): noise to scale and add. When `None`, standard normal noise shaped
            like `original_samples` is drawn.
        timesteps: index tensor into `self.discrete_sigmas`, one entry per batch item.

    Returns:
        `torch.FloatTensor`: `original_samples + noise * sigma`, with `sigma` broadcast over all non-batch axes.
    """
    # Indices and sigmas must live on the same device as the samples before indexing.
    timesteps = timesteps.to(original_samples.device)
    sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]

    # Append singleton axes until sigma broadcasts against the sample, whatever its rank
    # (the previous `[:, None, None, None]` only worked for 4-D image batches).
    sigmas = sigmas.flatten()
    while sigmas.dim() < original_samples.dim():
        sigmas = sigmas.unsqueeze(-1)

    if noise is None:
        noise = torch.randn_like(original_samples)

    return noise * sigmas + original_samples
class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise.
        sigma_min (`float`):
            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
            distribution of the data.
        sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample.
    """

    @property
    def has_state(self):
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # All arguments are captured by `register_to_config`; nothing else to initialise here.
        pass

    def create_state(self):
        """Create a fresh state with timesteps and sigmas derived from the registered config values."""
        state = ScoreSdeVeSchedulerState.create()
        return self.set_sigmas(
            state,
            self.config.num_train_timesteps,
            self.config.sigma_min,
            self.config.sigma_max,
            self.config.sampling_eps,
        )

    def set_timesteps(
        self,
        state: ScoreSdeVeSchedulerState,
        num_inference_steps: int,
        shape: Tuple = (),
        sampling_eps: Optional[float] = None,
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`): accepted for signature compatibility; not used by this method.
            sampling_eps (`float`, optional):
                final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` whose `timesteps` run linearly from 1 to `sampling_eps`.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
        return state.replace(timesteps=timesteps)

    def set_sigmas(
        self,
        state: ScoreSdeVeSchedulerState,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ) -> ScoreSdeVeSchedulerState:
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                initial noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional):
                final noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional):
                final timestep value (overrides value given at Scheduler instantiation).

        Returns:
            `ScoreSdeVeSchedulerState`: a copy of `state` with `discrete_sigmas` and `sigmas` filled in (and
            `timesteps`, if they were not set yet).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if state.timesteps is None:
            # Pass by keyword: the third positional parameter of `set_timesteps` is `shape`, not `sampling_eps`.
            state = self.set_timesteps(state, num_inference_steps, sampling_eps=sampling_eps)

        discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
        # Geometric interpolation between sigma_min and sigma_max, evaluated for all timesteps at once.
        sigmas = sigma_min * (sigma_max / sigma_min) ** state.timesteps

        return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)

    def get_adjacent_sigma(self, state, timesteps, t):
        """Return `state.discrete_sigmas[timesteps - 1]`, with zeros (shaped like `t`) where `timesteps == 0`."""
        return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])

    def step_pred(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep: current timestep; it is multiplied by `len(state.timesteps) - 1` and truncated to an integer
                to index `state.discrete_sigmas`.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key: JAX PRNG key used to draw the diffusion noise.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple` of
            `(prev_sample, prev_sample_mean, state)`.

        Raises:
            ValueError: if `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        timestep = timestep * jnp.ones(
            sample.shape[0],
        )
        # JAX arrays have no `.long()`; cast explicitly to get integer indices.
        timesteps = (timestep * (len(state.timesteps) - 1)).astype(jnp.int32)

        sigma = state.discrete_sigmas[timesteps]
        adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
        drift = jnp.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        diffusion = diffusion.flatten()
        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term
        # `random.split` returns a batch of keys; `random.normal` needs a single one.
        key = random.split(key, num=1)[0]
        noise = random.normal(key=key, shape=sample.shape)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)

    def step_correct(
        self,
        state: ScoreSdeVeSchedulerState,
        model_output: jnp.ndarray,
        sample: jnp.ndarray,
        key: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxSdeVeOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        Args:
            state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            key: JAX PRNG key used to draw the correction noise.
            return_dict (`bool`): option for returning tuple rather than FlaxSdeVeOutput class

        Returns:
            [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple` of
            `(prev_sample, state)`.

        Raises:
            ValueError: if `state.timesteps` has not been set.
        """
        if state.timesteps is None:
            raise ValueError(
                "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        # `random.split` returns a batch of keys; `random.normal` needs a single one.
        key = random.split(key, num=1)[0]
        noise = random.normal(key=key, shape=sample.shape)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = jnp.linalg.norm(model_output)
        noise_norm = jnp.linalg.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * jnp.ones(sample.shape[0])

        # compute corrected sample: model_output term and noise term
        step_size = step_size.flatten()
        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample, state)

        return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)

    def __len__(self):
        return self.config.num_train_timesteps
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance preserving stochastic differential equation (SDE) scheduler.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    UNDER CONSTRUCTION

    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
        # Mutable scheduler attributes; only `timesteps` is assigned elsewhere in this class.
        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
        """Store `num_inference_steps` timesteps spaced linearly from 1 down to `config.sampling_eps`."""
        final_t = self.config.sampling_eps
        self.timesteps = torch.linspace(1, final_t, num_inference_steps, device=device)

    def step_pred(self, score, x, t, generator=None):
        """
        Take one reverse step from the model output `score` at time `t`.

        Args:
            score: model output; it is rescaled by the marginal standard deviation before use.
            x: current sample.
            t: current time tensor (one entry per batch item is what the broadcasting below assumes).
            generator: optional random generator forwarded to `randn_tensor`.

        Returns:
            A tuple `(x, x_mean)`: the noised update and its noise-free mean.

        Raises:
            ValueError: if `set_timesteps` has not been called.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        def per_batch(values, reference):
            # one entry per batch item, then trailing singleton axes up to the rank of `reference`
            values = values.flatten()
            missing_axes = len(reference.shape) - len(values.shape)
            for _ in range(missing_axes):
                values = values.unsqueeze(-1)
            return values

        # TODO(Patrick) better comments + non-PyTorch
        # postprocess model score
        log_mean_coeff = (
            -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
        )
        std = per_batch(torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)), score)
        score = -score / std

        # compute
        num_steps = len(self.timesteps)
        dt = -1.0 / num_steps

        beta_t = per_batch(self.config.beta_min + t * (self.config.beta_max - self.config.beta_min), x)
        diffusion = torch.sqrt(beta_t)
        drift = -0.5 * beta_t * x
        drift = drift - diffusion**2 * score
        x_mean = x + drift * dt

        # add noise
        noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
        x = x_mean + diffusion * math.sqrt(-dt) * noise

        return x, x_mean

    def __len__(self):
        return self.config.num_train_timesteps
class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
    """
    NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This
    scheduler will be removed and replaced with DDPM.

    This is a modified DDPM Scheduler specifically for the karlo unCLIP model.

    This scheduler has some minor variations in how it calculates the learned range variance and dynamically
    re-calculates betas based off the timesteps it is skipping.

    The scheduler also uses a slightly different step ratio when computing timesteps to use for inference.

    See [`~DDPMScheduler`] for more information on DDPM scheduling

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        variance_type (`str`):
            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log`
            or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical
            stability.
        clip_sample_range (`float`, default `1.0`):
            The range to clip the sample between. See `clip_sample`.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process)
            or `sample` (directly predicting the noisy sample`)
        beta_schedule (`str`): must be `squaredcos_cap_v2`; any other value raises `ValueError`.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        variance_type: str = "fixed_small_log",
        clip_sample: bool = True,
        clip_sample_range: Optional[float] = 1.0,
        prediction_type: str = "epsilon",
        beta_schedule: str = "squaredcos_cap_v2",
    ):
        if beta_schedule != "squaredcos_cap_v2":
            raise ValueError("UnCLIPScheduler only supports `beta_schedule`: 'squaredcos_cap_v2'")

        self.betas = betas_for_alpha_bar(num_train_timesteps)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: the input sample, returned unchanged
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The
        different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy
        of the results.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional): device the timesteps tensor is moved to.
        """
        self.num_inference_steps = num_inference_steps
        last_train_step = self.config.num_train_timesteps - 1

        if num_inference_steps == 1:
            # The step ratio below would divide by zero. Every multi-step schedule starts at the last training
            # timestep, so a single-step schedule consists of just that timestep.
            timesteps = np.array([last_train_step], dtype=np.int64)
        else:
            step_ratio = last_train_step / (self.num_inference_steps - 1)
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def _posterior_terms(self, t, prev_timestep):
        """Return `(alpha_prod_t, alpha_prod_t_prev, beta_prod_t, beta_prod_t_prev, beta)` for stepping t -> prev."""
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        if prev_timestep == t - 1:
            # adjacent step: the pre-computed beta applies
            beta = self.betas[t]
        else:
            # skipped steps: derive the effective beta from the cumulative products
            beta = 1 - alpha_prod_t / alpha_prod_t_prev

        return alpha_prod_t, alpha_prod_t_prev, beta_prod_t, beta_prod_t_prev, beta

    def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None):
        """
        Compute the variance term for the step from `t` to `prev_timestep` (default `t - 1`).

        For `fixed_small_log` the returned value is `exp(0.5 * log(clamped variance))`; for `learned_range` it is a
        log-variance interpolated between the posterior variance and beta using `predicted_variance`. Any other
        `variance_type` returns the raw posterior variance.
        """
        if prev_timestep is None:
            prev_timestep = t - 1

        _, _, beta_prod_t, beta_prod_t_prev, beta = self._posterior_terms(t, prev_timestep)

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = beta_prod_t_prev / beta_prod_t * beta

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small_log":
            variance = torch.log(torch.clamp(variance, min=1e-20))
            variance = torch.exp(0.5 * variance)
        elif variance_type == "learned_range":
            # NOTE difference with DDPM scheduler
            min_log = variance.log()
            max_log = beta.log()

            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        prev_timestep: Optional[int] = None,
        generator=None,
        return_dict: bool = True,
    ) -> Union[UnCLIPSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
                Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        Raises:
            ValueError: if `config.prediction_type` or `variance_type` is not one of the supported values.
        """
        t = timestep

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":
            # the second half of the channels carries the predicted variance
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        if prev_timestep is None:
            prev_timestep = t - 1

        alpha_prod_t, alpha_prod_t_prev, beta_prod_t, beta_prod_t_prev, beta = self._posterior_terms(
            t, prev_timestep
        )
        alpha = self.alphas[t] if prev_timestep == t - 1 else 1 - beta

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`"
                " for the UnCLIPScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
        current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (skipped on the final step, t == 0)
        variance = 0
        if t > 0:
            variance_noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
            )

            variance = self._get_variance(
                t,
                predicted_variance=predicted_variance,
                prev_timestep=prev_timestep,
            )

            if self.variance_type == "learned_range":
                # `_get_variance` returned a log-variance here; convert it to a standard deviation
                variance = (0.5 * variance).exp()
            elif self.variance_type != "fixed_small_log":
                raise ValueError(
                    f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`"
                    " for the UnCLIPScheduler."
                )

            variance = variance * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Return `sqrt(alpha_cumprod[t]) * original_samples + sqrt(1 - alpha_cumprod[t]) * noise` per batch item."""
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, alpha_transform_type="cosine"):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`torch.Tensor`): float32 tensor of length `num_diffusion_timesteps` with the betas used by the
        scheduler to step the model outputs

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
    betas = [
        min(
            1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps),
            max_beta,
        )
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)
Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. solver_order (`int`, default `2`): the order of UniPC, also the p in UniPC-p; can be any positive integer. Note that the effective order of accuracy is `solver_order + 1` due to the UniC. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `predict_x0=True` and `thresholding=True` to use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). dynamic_thresholding_ratio (`float`, default `0.995`): the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://arxiv.org/abs/2205.11487). sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. predict_x0 (`bool`, default `True`): whether to use the updating algrithm on the predicted x0. 
See https://arxiv.org/abs/2211.01095 for details solver_type (`str`, default `bh2`): the solver type of UniPC. We recommend use `bh1` for unconditional sampling when steps < 10, and use `bh2` otherwise. lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. disable_corrector (`list`, default `[]`): decide which step to disable the corrector. For large guidance scale, the misalignment between the `epsilon_theta(x_t, c)`and `epsilon_theta(x_t^c, c)` might influence the convergence. This can be mitigated by disable the corrector at the first few steps (e.g., disable_corrector=[0]) solver_p (`SchedulerMixin`, default `None`): can be any other scheduler. If specified, the algorithm will become solver_p + UniC. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. timestep_spacing (`str`, default `"linspace"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. steps_offset (`int`, default `0`): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config def __init__( self, num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, predict_x0: bool = True, solver_type: str = "bh2", lower_order_final: bool = True, disable_corrector: List[int] = [], solver_p: SchedulerMixin = None, use_karras_sigmas: Optional[bool] = False, timestep_spacing: str = "linspace", steps_offset: int = 0, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
self.betas = ( torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # Currently we only support VP-type noise schedule self.alpha_t = torch.sqrt(self.alphas_cumprod) self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: self.register_to_config(solver_type="bh2") else: raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") self.predict_x0 = predict_x0 # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.timestep_list = [None] * solver_order self.lower_order_nums = 0 self.disable_corrector = disable_corrector self.solver_p = solver_p self.last_sample = None def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) timesteps += self.config.steps_offset elif self.config.timestep_spacing == "trailing": step_ratio = self.config.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) timesteps -= 1 else: raise ValueError( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() timesteps = np.flip(timesteps).copy().astype(np.int64) self.sigmas = torch.from_numpy(sigmas) # when num_inference_steps == num_train_timesteps, we can end up with # duplicates in timesteps. 
_, unique_indices = np.unique(timesteps, return_index=True) timesteps = timesteps[np.sort(unique_indices)] self.timesteps = torch.from_numpy(timesteps).to(device) self.num_inference_steps = len(timesteps) self.model_outputs = [ None, ] * self.config.solver_order self.lower_order_nums = 0 self.last_sample = None if self.solver_p: self.solver_p.set_timesteps(self.num_inference_steps, device=device) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: """ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." 
https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype batch_size, channels, height, width = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image sample = sample.reshape(batch_size, channels * height * width) abs_sample = sample.abs() # "a certain percentile absolute pixel value" s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) s = torch.clamp( s, min=1, max=self.config.sample_max_value ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" sample = sample.reshape(batch_size, channels, height, width) sample = sample.to(dtype) return sample def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor ) -> torch.FloatTensor: r""" Convert the model output to the corresponding type that the algorithm PC needs. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. Returns: `torch.FloatTensor`: the converted model output. 
""" if self.predict_x0: if self.config.prediction_type == "epsilon": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": x0_pred = model_output elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = alpha_t * sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the UniPCMultistepScheduler." ) if self.config.thresholding: x0_pred = self._threshold_sample(x0_pred) return x0_pred else: if self.config.prediction_type == "epsilon": return model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the UniPCMultistepScheduler." ) def multistep_uni_p_bh_update( self, model_output: torch.FloatTensor, prev_timestep: int, sample: torch.FloatTensor, order: int, ) -> torch.FloatTensor: """ One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. Args: model_output (`torch.FloatTensor`): direct outputs from learned diffusion model at the current timestep. prev_timestep (`int`): previous discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. order (`int`): the order of UniP at this step, also the p in UniPC-p. Returns: `torch.FloatTensor`: the sample tensor at the previous timestep. 
""" timestep_list = self.timestep_list model_output_list = self.model_outputs s0, t = self.timestep_list[-1], prev_timestep m0 = model_output_list[-1] x = sample if self.solver_p: x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s0 device = sample.device rks = [] D1s = [] for i in range(1, order): si = timestep_list[-(i + 1)] mi = model_output_list[-(i + 1)] lambda_si = self.lambda_t[si] rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) R = [] b = [] hh = -h if self.predict_x0 else h h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 h_phi_k = h_phi_1 / hh - 1 factorial_i = 1 if self.config.solver_type == "bh1": B_h = hh elif self.config.solver_type == "bh2": B_h = torch.expm1(hh) else: raise NotImplementedError() for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) b = torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) # (B, K) # for order 2, we use a simplified version if order == 2: rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) else: rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) else: D1s = None if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - alpha_t * B_h * pred_res else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - sigma_t * B_h * pred_res x_t = x_t.to(x.dtype) return x_t def multistep_uni_c_bh_update( self, this_model_output: torch.FloatTensor, this_timestep: int, 
last_sample: torch.FloatTensor, this_sample: torch.FloatTensor, order: int, ) -> torch.FloatTensor: """ One step for the UniC (B(h) version). Args: this_model_output (`torch.FloatTensor`): the model outputs at `x_t` this_timestep (`int`): the current timestep `t` last_sample (`torch.FloatTensor`): the generated sample before the last predictor: `x_{t-1}` this_sample (`torch.FloatTensor`): the generated sample after the last predictor: `x_{t}` order (`int`): the `p` of UniC-p at this step. Note that the effective order of accuracy should be order + 1 Returns: `torch.FloatTensor`: the corrected sample tensor at the current timestep. """ timestep_list = self.timestep_list model_output_list = self.model_outputs s0, t = timestep_list[-1], this_timestep m0 = model_output_list[-1] x = last_sample x_t = this_sample model_t = this_model_output lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] h = lambda_t - lambda_s0 device = this_sample.device rks = [] D1s = [] for i in range(1, order): si = timestep_list[-(i + 1)] mi = model_output_list[-(i + 1)] lambda_si = self.lambda_t[si] rk = (lambda_si - lambda_s0) / h rks.append(rk) D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) R = [] b = [] hh = -h if self.predict_x0 else h h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 h_phi_k = h_phi_1 / hh - 1 factorial_i = 1 if self.config.solver_type == "bh1": B_h = hh elif self.config.solver_type == "bh2": B_h = torch.expm1(hh) else: raise NotImplementedError() for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) factorial_i *= i + 1 h_phi_k = h_phi_k / hh - 1 / factorial_i R = torch.stack(R) b = torch.tensor(b, device=device) if len(D1s) > 0: D1s = torch.stack(D1s, dim=1) else: D1s = None # for order 1, we use a simplified version if order == 1: rhos_c = torch.tensor([0.5], dtype=x.dtype, 
device=device) else: rhos_c = torch.linalg.solve(R, b) if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) else: corr_res = 0 D1_t = model_t - m0 x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) else: corr_res = 0 D1_t = model_t - m0 x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) x_t = x_t.to(x.dtype) return x_t def step( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the multistep UniPC. Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. 
""" if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero() if len(step_index) == 0: step_index = len(self.timesteps) - 1 else: step_index = step_index.item() use_corrector = ( step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None ) model_output_convert = self.convert_model_output(model_output, timestep, sample) if use_corrector: sample = self.multistep_uni_c_bh_update( this_model_output=model_output_convert, this_timestep=timestep, last_sample=self.last_sample, this_sample=sample, order=self.this_order, ) # now prepare to run the predictor prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] for i in range(self.config.solver_order - 1): self.model_outputs[i] = self.model_outputs[i + 1] self.timestep_list[i] = self.timestep_list[i + 1] self.model_outputs[-1] = model_output_convert self.timestep_list[-1] = timestep if self.config.lower_order_final: this_order = min(self.config.solver_order, len(self.timesteps) - step_index) else: this_order = self.config.solver_order self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep assert self.this_order > 0 self.last_sample = sample prev_sample = self.multistep_uni_p_bh_update( model_output=model_output, # pass the original non-converted model output, in case solver-p is used prev_timestep=prev_timestep, sample=sample, order=self.this_order, ) if self.lower_order_nums < self.config.solver_order: self.lower_order_nums += 1 if not return_dict: return (prev_sample,) return SchedulerOutput(prev_sample=prev_sample) def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising 
model input depending on the current timestep. Args: sample (`torch.FloatTensor`): input sample Returns: `torch.FloatTensor`: scaled input sample """ return sample # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: # Make sure alphas_cumprod and timestep have same device and dtype as original_samples alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): return self.config.num_train_timesteps ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_utils.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import importlib import os from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, Union import torch from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" # NOTE: We make this type an enum because it simplifies usage in docs and prevents # circular imports when used for `_compatibles` within the schedulers module. # When it's used as a type in pipelines, it really is a Union because the actual # scheduler instance is passed in. class KarrasDiffusionSchedulers(Enum): DDIMScheduler = 1 DDPMScheduler = 2 PNDMScheduler = 3 LMSDiscreteScheduler = 4 EulerDiscreteScheduler = 5 HeunDiscreteScheduler = 6 EulerAncestralDiscreteScheduler = 7 DPMSolverMultistepScheduler = 8 DPMSolverSinglestepScheduler = 9 KDPM2DiscreteScheduler = 10 KDPM2AncestralDiscreteScheduler = 11 DEISMultistepScheduler = 12 UniPCMultistepScheduler = 13 DPMSolverSDEScheduler = 14 @dataclass class SchedulerOutput(BaseOutput): """ Base class for the scheduler's step function output. Args: prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. """ prev_sample: torch.FloatTensor class SchedulerMixin: """ Mixin containing common functions for the schedulers. Class attributes: - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that `from_config` can be used from a class different than the one used to save the config (should be overridden by parent class). 
""" config_name = SCHEDULER_CONFIG_NAME _compatibles = [] has_compatibles = True @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Dict[str, Any] = None, subfolder: Optional[str] = None, return_unused_kwargs=False, **kwargs, ): r""" Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - A path to a *directory* containing the schedluer configurations saved using [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. subfolder (`str`, *optional*): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. return_unused_kwargs (`bool`, *optional*, defaults to `False`): Whether kwargs that are not consumed by the Python class should be returned or not. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 
output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a firewalled environment. """ config, kwargs, commit_hash = cls.load_config( pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, return_unused_kwargs=True, return_commit_hash=True, **kwargs, ) return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the [`~SchedulerMixin.from_pretrained`] class method. Args: save_directory (`str` or `os.PathLike`): Directory where the configuration JSON file will be saved (will be created if it does not exist). 
""" self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) @property def compatibles(self): """ Returns all schedulers that are compatible with this scheduler Returns: `List[SchedulerMixin]`: List of compatible schedulers """ return self._get_compatibles() @classmethod def _get_compatibles(cls): compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) diffusers_library = importlib.import_module(__name__.split(".")[0]) compatible_classes = [ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_utils_flax.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import math import os from dataclasses import dataclass from enum import Enum from typing import Any, Dict, Optional, Tuple, Union import flax import jax.numpy as jnp from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" # NOTE: We make this type an enum because it simplifies usage in docs and prevents # circular imports when used for `_compatibles` within the schedulers module. # When it's used as a type in pipelines, it really is a Union because the actual # scheduler instance is passed in. 
class FlaxKarrasDiffusionSchedulers(Enum): FlaxDDIMScheduler = 1 FlaxDDPMScheduler = 2 FlaxPNDMScheduler = 3 FlaxLMSDiscreteScheduler = 4 FlaxDPMSolverMultistepScheduler = 5 @dataclass class FlaxSchedulerOutput(BaseOutput): """ Base class for the scheduler's step function output. Args: prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. """ prev_sample: jnp.ndarray class FlaxSchedulerMixin: """ Mixin containing common functions for the schedulers. Class attributes: - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that `from_config` can be used from a class different than the one used to save the config (should be overridden by parent class). """ config_name = SCHEDULER_CONFIG_NAME ignore_for_config = ["dtype"] _compatibles = [] has_compatibles = True @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Dict[str, Any] = None, subfolder: Optional[str] = None, return_unused_kwargs=False, **kwargs, ): r""" Instantiate a Scheduler class from a pre-defined JSON-file. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. subfolder (`str`, *optional*): In case the relevant files are located inside a subfolder of the model repo (either remote in huggingface.co or downloaded locally), you can specify the folder name here. return_unused_kwargs (`bool`, *optional*, defaults to `False`): Whether kwargs that are not consumed by the Python class should be returned or not. 
cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received files. Will attempt to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a firewalled environment. 
""" config, kwargs = cls.load_config( pretrained_model_name_or_path=pretrained_model_name_or_path, subfolder=subfolder, return_unused_kwargs=True, **kwargs, ) scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): state = scheduler.create_state() if return_unused_kwargs: return scheduler, state, unused_kwargs return scheduler, state def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the [`~FlaxSchedulerMixin.from_pretrained`] class method. Args: save_directory (`str` or `os.PathLike`): Directory where the configuration JSON file will be saved (will be created if it does not exist). """ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) @property def compatibles(self): """ Returns all schedulers that are compatible with this scheduler Returns: `List[SchedulerMixin]`: List of compatible schedulers """ return self._get_compatibles() @classmethod def _get_compatibles(cls): compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) diffusers_library = importlib.import_module(__name__.split(".")[0]) compatible_classes = [ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: assert len(shape) >= x.ndim return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape) def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999, dtype=jnp.float32) -> jnp.ndarray: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up to that part of the diffusion process. Args: num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. Returns: betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs """ def alpha_bar(time_step): return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return jnp.array(betas, dtype=dtype) @flax.struct.dataclass class CommonSchedulerState: alphas: jnp.ndarray betas: jnp.ndarray alphas_cumprod: jnp.ndarray @classmethod def create(cls, scheduler): config = scheduler.config if config.trained_betas is not None: betas = jnp.asarray(config.trained_betas, dtype=scheduler.dtype) elif config.beta_schedule == "linear": betas = jnp.linspace(config.beta_start, config.beta_end, config.num_train_timesteps, dtype=scheduler.dtype) elif config.beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
betas = ( jnp.linspace( config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype ) ** 2 ) elif config.beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule betas = betas_for_alpha_bar(config.num_train_timesteps, dtype=scheduler.dtype) else: raise NotImplementedError( f"beta_schedule {config.beta_schedule} is not implemented for scheduler {scheduler.__class__.__name__}" ) alphas = 1.0 - betas alphas_cumprod = jnp.cumprod(alphas, axis=0) return cls( alphas=alphas, betas=betas, alphas_cumprod=alphas_cumprod, ) def get_sqrt_alpha_prod( state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray ): alphas_cumprod = state.alphas_cumprod sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = sqrt_alpha_prod.flatten() sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) return sqrt_alpha_prod, sqrt_one_minus_alpha_prod def add_noise_common( state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray ): sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps) noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray): sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps) velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity ================================================ FILE: 6DoF/diffusers/schedulers/scheduling_vq_diffusion.py 
================================================
# Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
class VQDiffusionSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
            Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.LongTensor


def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
    """
    Convert batch of vector of class indices into batch of log onehot vectors

    Args:
        x (`torch.LongTensor` of shape `(batch size, vector length)`):
            Batch of class indices

        num_classes (`int`):
            number of classes to be used for the onehot vectors

    Returns:
        `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
            Log onehot vectors
    """
    x_onehot = F.one_hot(x, num_classes)
    # Move the class axis from last to second position.
    x_onehot = x_onehot.permute(0, 2, 1)
    # Clamp before the log so that the zero entries become log(1e-30) instead of -inf.
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x


def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
    """
    Apply gumbel noise to `logits`

    The noise is `-log(-log(u))` for `u` drawn with `torch.rand` (using `generator`), with `1e-30` added inside each
    log.
    """
    uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
    gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
    noised = gumbel_noise + logits
    return noised


def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
    """
    Cumulative and non-cumulative alpha schedules.

    See section 4.1.

    Returns:
        `(at, att)`: the per-step ratios and the cumulative values. The cumulative values are interpolated linearly
        from `alpha_cum_start` to `alpha_cum_end`, and `att` has a final `1` appended.
    """
    # NOTE(review): divides by `num_diffusion_timesteps - 1`, so a single-step schedule divides by zero.
    att = (
        np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
        + alpha_cum_start
    )
    # Prepend 1 so that the first ratio below is `att[0] / 1`.
    att = np.concatenate(([1], att))
    at = att[1:] / att[:-1]
    # Drop the prepended 1 and append a trailing 1.
    att = np.concatenate((att[1:], [1]))
    return at, att


def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
    """
    Cumulative and non-cumulative gamma schedules.

    See section 4.1.

    Returns:
        `(ct, ctt)`: the per-step values and the cumulative values. The cumulative values are interpolated linearly
        from `gamma_cum_start` to `gamma_cum_end`, and `ctt` has a final `0` appended.
    """
    # NOTE(review): divides by `num_diffusion_timesteps - 1`, so a single-step schedule divides by zero.
    ctt = (
        np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
        + gamma_cum_start
    )
    ctt = np.concatenate(([0], ctt))
    one_minus_ctt = 1 - ctt
    one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
    ct = 1 - one_minus_ct
    # Drop the prepended 0 and append a trailing 0.
    ctt = np.concatenate((ctt[1:], [0]))
    return ct, ctt


class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
    """
    The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.

    The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
    diffusion timestep.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2111.14822

    Args:
        num_vec_classes (`int`):
            The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
            latent pixel.

        num_train_timesteps (`int`):
            Number of diffusion steps used to train the model.

        alpha_cum_start (`float`):
            The starting cumulative alpha value.

        alpha_cum_end (`float`):
            The ending cumulative alpha value.

        gamma_cum_start (`float`):
            The starting cumulative gamma value.

        gamma_cum_end (`float`):
            The ending cumulative gamma value.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_vec_classes: int,
        num_train_timesteps: int = 100,
        alpha_cum_start: float = 0.99999,
        alpha_cum_end: float = 0.000009,
        gamma_cum_start: float = 0.000009,
        gamma_cum_end: float = 0.99999,
    ):
        self.num_embed = num_vec_classes

        # By convention, the index for the mask class is the last class index
        self.mask_class = self.num_embed - 1

        at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
        ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)

        # beta is whatever probability mass is left after alpha and gamma, split evenly over the non-mask classes.
        num_non_mask_classes = self.num_embed - 1
        bt = (1 - at - ct) / num_non_mask_classes
        btt = (1 - att - ctt) / num_non_mask_classes

        # The logs are taken in float64 and only cast to float32 when stored on `self` below.
        at = torch.tensor(at.astype("float64"))
        bt = torch.tensor(bt.astype("float64"))
        ct = torch.tensor(ct.astype("float64"))
        log_at = torch.log(at)
        log_bt = torch.log(bt)
        log_ct = torch.log(ct)

        att = torch.tensor(att.astype("float64"))
        btt = torch.tensor(btt.astype("float64"))
        ctt = torch.tensor(ctt.astype("float64"))
        log_cumprod_at = torch.log(att)
        log_cumprod_bt = torch.log(btt)
        log_cumprod_ct = torch.log(ctt)

        self.log_at = log_at.float()
        self.log_bt = log_bt.float()
        self.log_ct = log_ct.float()
        self.log_cumprod_at = log_cumprod_at.float()
        self.log_cumprod_bt = log_cumprod_bt.float()
        self.log_cumprod_ct = log_cumprod_ct.float()

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.

            device (`str` or `torch.device`):
                device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
        """
        self.num_inference_steps = num_inference_steps
        # Timesteps run from `num_inference_steps - 1` down to 0.
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps).to(device)

        self.log_at = self.log_at.to(device)
        self.log_bt = self.log_bt.to(device)
        self.log_ct = self.log_ct.to(device)
        self.log_cumprod_at = self.log_cumprod_at.to(device)
        self.log_cumprod_bt = self.log_cumprod_bt.to(device)
        self.log_cumprod_ct = self.log_cumprod_ct.to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: torch.long,
        sample: torch.LongTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[VQDiffusionSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the
        docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed.

        Args:
            model_output (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                prediction for the masked class as the initial unnoised image cannot be masked. (Passed to
                `q_posterior` as `log_p_x_0`.)

            timestep (`torch.long`):
                The timestep that determines which transition matrices are used. (Passed to `q_posterior` as `t`.)

            sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`. (Passed to `q_posterior` as `x_t`.)

            generator (`torch.Generator` or None):
                RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from.

            return_dict (`bool`):
                option for returning tuple rather than VQDiffusionSchedulerOutput class

        Returns:
            [`VQDiffusionSchedulerOutput`] or `tuple`:
            [`VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the
            first element is the sample tensor.
        """
        # At timestep 0 the model output is used directly; there is no posterior step.
        if timestep == 0:
            log_p_x_t_min_1 = model_output
        else:
            log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)

        # Add gumbel noise, then take the argmax over the class axis (dim 1).
        log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)

        x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)

        if not return_dict:
            return (x_t_min_1,)

        return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)

    def q_posterior(self, log_p_x_0, x_t, t):
        """
        Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).

        Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
        forward probabilities.

        Equation (11) stated in terms of forward probabilities via Equation (5):

        Where:
        - the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0)

        p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )

        Args:
            log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                prediction for the masked class as the initial unnoised image cannot be masked.

            x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`

            t (torch.Long):
                The timestep that determines which transition matrix is used.

        Returns:
            `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
                The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
        """
        log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)

        log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
            t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
        )

        log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
            t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
        )

        # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0)          ...      p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
        #               .                    .                                   .
        #               .                            .                           .
        #               .                                      .                 .
        # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})  ...      p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
        q = log_p_x_0 - log_q_x_t_given_x_0

        # sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
        # sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
        q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)

        # p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0          ...      p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
        #                        .                             .                                   .
        #                        .                                     .                           .
        #                        .                                               .                 .
        # p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0  ...      p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
        q = q - q_log_sum_exp

        # (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}          ...      (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
        #                                         .                                                .                                              .
        #                                         .                                                        .                                      .
        #                                         .                                                                  .                            .
        # (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}  ...      (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
        # c_cumulative_{t-1}                                                                                 ...      c_cumulative_{t-1}
        q = self.apply_cumulative_transitions(q, t - 1)

        # ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0              ...      ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
        #                                                            .                                                                 .                                              .
        #                                                            .                                                                         .                                      .
        #                                                            .                                                                                   .                            .
        # ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0  ...      ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
        # c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0                                                                                       ...      c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_n
        log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp

        # For each column, there are two possible cases.
        #
        # Where:
        # - sum(p_n(x_0)) is summing over all classes for x_0
        # - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
        # - C_j is the class transitioning to
        #
        # 1. x_t is masked i.e. x_t = c_k
        #
        # Simplifying the expression, the column vector is:
        #                                                      .
        #                                                      .
        #                                                      .
        # (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
        #                                                      .
        #                                                      .
        #                                                      .
        # (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
        #
        # From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
        #
        # For the other rows, we can state the equation as ...
        #
        # (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{t-1} * p(x_0=c_{k-1})]
        #
        # This verifies the other rows.
        #
        # 2. x_t is not masked
        #
        # Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
        #                                                      .
        #                                                      .
        #                                                      .
        # C_j != C_i:        b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
        #                                                      .
        #                                                      .
        #                                                      .
        # C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
        #                                                      .
        #                                                      .
        #                                                      .
        # 0
        #
        # The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
        return log_p_x_t_min_1

    def log_Q_t_transitioning_to_known_class(
        self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
    ):
        """
        Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
        latent pixel in `x_t`.

        See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
        is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.

        Args:
            t (torch.Long):
                The timestep that determines which transition matrix is used.

            x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`.

            log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
                The log one-hot vectors of `x_t`

            cumulative (`bool`):
                If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
                we use the cumulative transition matrix `0`->`t`.

        Returns:
            `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)` when `cumulative` is
            `True`, or `(batch size, num classes, num latent pixels)` when `cumulative` is `False`:
                Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
                transition matrix.

                When cumulative, returns `self.num_embed - 1` rows because the initial latent pixel cannot be
                masked. When non cumulative, the row for transitions from the masked class is appended, giving
                `self.num_embed` rows.

                Where:
                - `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
                - C_0 is a class of a latent pixel embedding
                - C_k is the class of the masked latent pixel

                non-cumulative result (omitting logarithms):
                ```
                q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
                          .      .                     .
                          .               .            .
                          .                      .     .
                q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
                ```

                cumulative result (omitting logarithms):
                ```
                q_0_cumulative(x_t | x_0 = C_0)    ...  q_n_cumulative(x_t | x_0 = C_0)
                          .               .                          .
                          .                        .                 .
                          .                               .          .
                q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
                ```
        """
        if cumulative:
            a = self.log_cumprod_at[t]
            b = self.log_cumprod_bt[t]
            c = self.log_cumprod_ct[t]
        else:
            a = self.log_at[t]
            b = self.log_bt[t]
            c = self.log_ct[t]

        if not cumulative:
            # The values in the onehot vector can also be used as the logprobs for transitioning
            # from masked latent pixels. If we are not calculating the cumulative transitions,
            # we need to save these vectors to be re-appended to the final matrix so the values
            # aren't overwritten.
            #
            # `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector
            # if x_t is not masked
            #
            # `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector
            # if x_t is masked
            log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)

        # `index_to_log_onehot` will add onehot vectors for masked pixels,
        # so the default one hot matrix has one too many rows. See the doc string
        # for an explanation of the dimensionality of the returned matrix.
        log_onehot_x_t = log_onehot_x_t[:, :-1, :]

        # this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
        #
        # Don't worry about what values this sets in the columns that mark transitions
        # to masked latent pixels. They are overwritten later with the `mask_class_mask`.
        #
        # Looking at the below logspace formula in non-logspace, each value will evaluate to either
        # `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
        # or
        # `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
        #
        # See equation 7 for more details.
        log_Q_t = (log_onehot_x_t + a).logaddexp(b)

        # The whole column of each masked pixel is `c`
        mask_class_mask = x_t == self.mask_class
        mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
        log_Q_t[mask_class_mask] = c

        if not cumulative:
            # Re-append the row saved above for transitions out of the masked class.
            log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)

        return log_Q_t

    def apply_cumulative_transitions(self, q, t):
        """
        Apply the cumulative transition parameters at timestep `t` to `q` in log space.

        Each entry of `q` becomes `logaddexp(q + a, b)`, and one extra row filled with `c` is appended along dim 1,
        where `a`, `b`, `c` are `log_cumprod_at[t]`, `log_cumprod_bt[t]`, `log_cumprod_ct[t]`.
        """
        bsz = q.shape[0]
        a = self.log_cumprod_at[t]
        b = self.log_cumprod_bt[t]
        c = self.log_cumprod_ct[t]

        num_latent_pixels = q.shape[2]
        c = c.expand(bsz, 1, num_latent_pixels)

        q = (q + a).logaddexp(b)
        q = torch.cat((q, c), dim=1)

        return q


================================================
FILE: 6DoF/diffusers/training_utils.py
================================================
import contextlib
import copy
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import torch

from .utils import deprecate, is_transformers_available


if is_transformers_available():
    import transformers


def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
min_decay (float): The minimum decay factor for the exponential moving average. update_after_step (int): The number of steps to wait before starting to update the EMA weights. use_ema_warmup (bool): Whether to use EMA warmup. inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA weights will be stored on CPU. @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). """ if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " "Please pass the parameters of the module instead." ) deprecate( "passing a `torch.nn.Module` to `ExponentialMovingAverage`", "1.0.0", deprecation_message, standard_warn=False, ) parameters = parameters.parameters() # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility use_ema_warmup = True if kwargs.get("max_value", None) is not None: deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) decay = kwargs["max_value"] if kwargs.get("min_value", None) is not None: deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." 
deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) min_decay = kwargs["min_value"] parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] if kwargs.get("device", None) is not None: deprecation_message = "The `device` argument is deprecated. Please use `to` instead." deprecate("device", "1.0.0", deprecation_message, standard_warn=False) self.to(device=kwargs["device"]) self.temp_stored_params = None self.decay = decay self.min_decay = min_decay self.update_after_step = update_after_step self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma self.power = power self.optimization_step = 0 self.cur_decay_value = None # set in `step()` self.model_cls = model_cls self.model_config = model_config @classmethod def from_pretrained(cls, path, model_cls) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) ema_model.load_state_dict(ema_kwargs) return ema_model def save_pretrained(self, path): if self.model_cls is None: raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") if self.model_config is None: raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") model = self.model_cls.from_config(self.model_config) state_dict = self.state_dict() state_dict.pop("shadow_params", None) model.register_to_config(**state_dict) self.copy_to(model.parameters()) model.save_pretrained(path) def get_decay(self, optimization_step: int) -> float: """ Compute the decay factor for the exponential moving average. 
""" step = max(0, optimization_step - self.update_after_step - 1) if step <= 0: return 0.0 if self.use_ema_warmup: cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power else: cur_decay_value = (1 + step) / (10 + step) cur_decay_value = min(cur_decay_value, self.decay) # make sure decay is not smaller than min_decay cur_decay_value = max(cur_decay_value, self.min_decay) return cur_decay_value @torch.no_grad() def step(self, parameters: Iterable[torch.nn.Parameter]): if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " "Please pass the parameters of the module instead." ) deprecate( "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", "1.0.0", deprecation_message, standard_warn=False, ) parameters = parameters.parameters() parameters = list(parameters) self.optimization_step += 1 # Compute the decay factor for the exponential moving average. decay = self.get_decay(self.optimization_step) self.cur_decay_value = decay one_minus_decay = 1 - decay context_manager = contextlib.nullcontext if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed for s_param, param in zip(self.shadow_params, parameters): if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) with context_manager(): if param.requires_grad: s_param.sub_(one_minus_decay * (s_param - param)) else: s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ Copy current averaged parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. 
""" parameters = list(parameters) for s_param, param in zip(self.shadow_params, parameters): param.data.copy_(s_param.to(param.device).data) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] def state_dict(self) -> dict: r""" Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during checkpointing to save the ema state dict. """ # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "min_decay": self.min_decay, "optimization_step": self.optimization_step, "update_after_step": self.update_after_step, "use_ema_warmup": self.use_ema_warmup, "inv_gamma": self.inv_gamma, "power": self.power, "shadow_params": self.shadow_params, } def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Save the current parameters for restoring later. parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. 
def load_state_dict(self, state_dict: dict) -> None:
    r"""
    Load the ExponentialMovingAverage state from `state_dict`.

    Keys that are absent keep their current value. All supplied values are validated before any
    attribute is written, so a rejected `state_dict` leaves this object unchanged.

    Args:
        state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.

    Raises:
        ValueError: if any supplied value has the wrong type or `decay` is outside `[0, 1]`.
    """
    # deepcopy, to be consistent with module API
    state_dict = copy.deepcopy(state_dict)

    decay = state_dict.get("decay", self.decay)
    if decay < 0.0 or decay > 1.0:
        raise ValueError("Decay must be between 0 and 1")

    min_decay = state_dict.get("min_decay", self.min_decay)
    if not isinstance(min_decay, float):
        raise ValueError("Invalid min_decay")

    optimization_step = state_dict.get("optimization_step", self.optimization_step)
    if not isinstance(optimization_step, int):
        raise ValueError("Invalid optimization_step")

    update_after_step = state_dict.get("update_after_step", self.update_after_step)
    if not isinstance(update_after_step, int):
        raise ValueError("Invalid update_after_step")

    use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
    if not isinstance(use_ema_warmup, bool):
        raise ValueError("Invalid use_ema_warmup")

    inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
    if not isinstance(inv_gamma, (float, int)):
        raise ValueError("Invalid inv_gamma")

    power = state_dict.get("power", self.power)
    if not isinstance(power, (float, int)):
        raise ValueError("Invalid power")

    # `shadow_params` is optional; when missing the current tensors are kept and not re-validated.
    shadow_params = state_dict.get("shadow_params", None)
    if shadow_params is not None:
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

    # Everything validated: commit in one go so a failure above never leaves a half-loaded state.
    self.decay = decay
    self.min_decay = min_decay
    self.optimization_step = optimization_step
    self.update_after_step = update_after_step
    self.use_ema_warmup = use_ema_warmup
    self.inv_gamma = inv_gamma
    self.power = power
    if shadow_params is not None:
        self.shadow_params = shadow_params
def check_min_version(min_version):
    """
    Raise `ImportError` when the installed diffusers version is older than `min_version`.

    A `min_version` containing "dev" produces a message that points to the source-install
    instructions; any other value produces a plain minimum-version message.
    """
    if not (version.parse(__version__) < version.parse(min_version)):
        return

    if "dev" in min_version:
        reason = (
            "This example requires a source install from HuggingFace diffusers (see "
            "`https://huggingface.co/docs/diffusers/installation#install-from-source`),"
        )
    else:
        reason = f"This example requires a minimum version of {min_version},"

    raise ImportError(reason + f" but the version found is {__version__}.\n")
def apply_forward_hook(method):
    """
    Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful
    for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the
    appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`].

    This decorator looks inside the internal `_hf_hook` property to find a registered offload hook.

    The method is returned undecorated when accelerate is not available or its base version is older than 0.17.0.
    Otherwise the returned wrapper keeps the wrapped method's metadata (`__name__`, `__doc__`, signature).

    :param method: The method to decorate. This method should be a method of a PyTorch module.
    """
    # Local import: this block cannot see or edit the module's import section.
    import functools

    if not is_accelerate_available():
        return method
    accelerate_version = version.parse(accelerate.__version__).base_version
    if version.parse(accelerate_version) < version.parse("0.17.0"):
        return method

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Trigger the hook's `pre_forward` (when one is attached) before running the real method.
        if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"):
            self._hf_hook.pre_forward(self)
        return method(self, *args, **kwargs)

    return wrapper
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
    """
    Emit `FutureWarning`s for deprecated arguments/attributes and hand back their values.

    Args:
        *args: either a single `(attribute, version_name, message)` triple passed as three positional
            arguments, or several such triples each passed as a tuple.
        take_from: where to look for the deprecated values. A dict has the attribute popped from it
            (the dict is modified in place), any other object is read with `getattr`, and `None` means
            "only warn".
        standard_warn: when true, the standard "... is deprecated and will be removed ..." sentence is
            put in front of `message`; when false only `message` is emitted.
        stacklevel: forwarded to `warnings.warn`.

    Returns:
        `None` if no value was collected, the single value if exactly one was collected, otherwise a
        tuple of the collected values in the order of `args`.

    Raises:
        ValueError: if the installed base version is already >= a triple's `version_name`.
        TypeError: if `take_from` is a dict that still contains keys after all triples were processed.
    """
    from .. import __version__

    deprecated_kwargs = take_from
    values = ()
    # Normalise the "three positional arguments" call form into a one-element tuple of triples.
    if not isinstance(args[0], tuple):
        args = (args,)

    for attribute, version_name, message in args:
        # A deprecation whose removal version has been reached must be deleted from the code base.
        if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
            raise ValueError(
                f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
                f" version {__version__} is >= {version_name}"
            )

        warning = None
        if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
            # Consumes the key, so anything left in the dict afterwards is an unknown keyword.
            values += (deprecated_kwargs.pop(attribute),)
            warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
        elif hasattr(deprecated_kwargs, attribute):
            values += (getattr(deprecated_kwargs, attribute),)
            warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
        elif deprecated_kwargs is None:
            warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."

        if warning is not None:
            # Parses as `(warning + " ") if standard_warn else ""`: the prefix is dropped entirely
            # when `standard_warn` is false.
            warning = warning + " " if standard_warn else ""
            warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)

    if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
        # Index 1 is the frame that called `deprecate`; the error is reported as coming from it.
        call_frame = inspect.getouterframes(inspect.currentframe())[1]
        filename = call_frame.filename
        line_number = call_frame.lineno
        function = call_frame.function
        # Only the first leftover key is named in the error.
        key, value = next(iter(deprecated_kwargs.items()))
        raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")

    if len(values) == 0:
        return
    elif len(values) == 1:
        return values[0]
    return values
def replace_example_docstring(example_docstring):
    """
    Build a decorator that swaps the `Examples:` placeholder line of a docstring for `example_docstring`.

    The first docstring line that consists solely of `Example:` or `Examples:` (surrounding whitespace
    allowed) is replaced; all other lines are kept as they are.

    Raises (from the returned decorator):
        ValueError: if the decorated function has no docstring or its docstring has no placeholder line.
    """

    def docstring_decorator(fn):
        func_doc = fn.__doc__
        # A function without a docstring cannot contain the placeholder; fall through to the same
        # ValueError instead of failing with AttributeError on `None.split`.
        lines = [] if func_doc is None else func_doc.split("\n")

        for index, line in enumerate(lines):
            if re.search(r"^\s*Examples?:\s*$", line) is not None:
                lines[index] = example_docstring
                break
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
                f"current docstring is:\n{func_doc}"
            )

        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
# Autogenerated placeholder (see the file header: produced by `make fix-copies`, do not edit by hand).
# Construction and both loader classmethods hand control to `requires_backends` with the backend
# list below. NOTE(review): presumably that raises when a backend is missing - confirm in import_utils.
class FlaxStableDiffusionControlNetPipeline(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["flax", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax", "transformers"])
# Autogenerated placeholder (see the file header: produced by `make fix-copies`, do not edit by hand).
# Every entry point only calls `requires_backends` with the "flax" backend.
# NOTE(review): presumably that raises when flax is missing - confirm in import_utils.
class FlaxControlNetModel(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
_backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxKarrasVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxLMSDiscreteScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxPNDMScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxSchedulerMixin(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] def __init__(self, *args, **kwargs): requires_backends(self, 
["flax"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["flax"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) ================================================ FILE: 6DoF/diffusers/utils/dummy_note_seq_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class MidiProcessor(metaclass=DummyObject): _backends = ["note_seq"] def __init__(self, *args, **kwargs): requires_backends(self, ["note_seq"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["note_seq"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["note_seq"]) ================================================ FILE: 6DoF/diffusers/utils/dummy_onnx_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class OnnxRuntimeModel(metaclass=DummyObject): _backends = ["onnx"] def __init__(self, *args, **kwargs): requires_backends(self, ["onnx"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["onnx"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["onnx"]) ================================================ FILE: 6DoF/diffusers/utils/dummy_pt_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
# Autogenerated placeholder (see the file header: produced by `make fix-copies`, do not edit by hand).
# Every entry point only calls `requires_backends` with the "torch" backend.
# NOTE(review): presumably that raises when torch is missing - confirm in import_utils.
class AutoencoderKL(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet2DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UNet3DConditionModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class VQModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) def get_constant_schedule_with_warmup(*args, **kwargs): requires_backends(get_constant_schedule_with_warmup, ["torch"]) def get_cosine_schedule_with_warmup(*args, **kwargs): requires_backends(get_cosine_schedule_with_warmup, ["torch"]) def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, 
["torch"]) def get_linear_schedule_with_warmup(*args, **kwargs): requires_backends(get_linear_schedule_with_warmup, ["torch"]) def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) def get_scheduler(*args, **kwargs): requires_backends(get_scheduler, ["torch"]) class AudioPipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ConsistencyModelPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DanceDiffusionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDIMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiffusionPipeline(metaclass=DummyObject): _backends = 
["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DiTPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ImagePipelineOutput(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KarrasVePipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LDMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class LDMSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PNDMPipeline(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def 
# Autogenerated placeholder (produced by `make fix-copies` according to the file header).
# Every entry point only calls `requires_backends` with the "torch" backend.
# NOTE(review): presumably that raises when torch is missing - confirm in import_utils.
class RePaintPipeline(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
@classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMParallelScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DDPMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DEISMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class DPMSolverSinglestepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, 
**kwargs): requires_backends(cls, ["torch"]) class EulerAncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EulerDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class HeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KarrasVeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class KDPM2AncestralDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class 
KDPM2DiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class PNDMScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class RePaintScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class SchedulerMixin(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class ScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UnCLIPScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class UniPCMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, 
*args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class VQDiffusionScheduler(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) class EMAModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) ================================================ FILE: 6DoF/diffusers/utils/dummy_torch_and_librosa_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. 
from ..utils import DummyObject, requires_backends


# Stand-in class for `AudioDiffusionPipeline`. The constructor and both loader
# classmethods do nothing except call `requires_backends` with the
# torch + librosa backend list.
# NOTE(review): presumably `requires_backends` raises when a listed backend is
# not installed -- confirm against its definition in `..utils`.
class AudioDiffusionPipeline(metaclass=DummyObject):
    # Same backend names that every method below passes to `requires_backends`.
    _backends = ["torch", "librosa"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never inspected.
        requires_backends(self, ["torch", "librosa"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "librosa"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "librosa"])


# Stand-in class for `Mel`; identical shape to the class above, with the
# torch + librosa backend list.
class Mel(metaclass=DummyObject):
    _backends = ["torch", "librosa"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "librosa"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "librosa"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "librosa"])


================================================
FILE: 6DoF/diffusers/utils/dummy_torch_and_scipy_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Stand-in class for `LMSDiscreteScheduler`. Every entry point only forwards to
# `requires_backends` with the torch + scipy backend list.
# NOTE(review): the effect of `requires_backends` is not visible here --
# confirm in `..utils`.
class LMSDiscreteScheduler(metaclass=DummyObject):
    # Same backend names that every method below passes to `requires_backends`.
    _backends = ["torch", "scipy"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never inspected.
        requires_backends(self, ["torch", "scipy"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "scipy"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "scipy"])


================================================
FILE: 6DoF/diffusers/utils/dummy_torch_and_torchsde_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Stand-in class for `DPMSolverSDEScheduler`. The constructor and both loader
# classmethods do nothing except call `requires_backends` with the
# torch + torchsde backend list.
# NOTE(review): presumably `requires_backends` raises when a listed backend is
# not installed -- confirm against its definition in `..utils`.
class DPMSolverSDEScheduler(metaclass=DummyObject):
    # Same backend names that every method below passes to `requires_backends`.
    _backends = ["torch", "torchsde"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never inspected.
        requires_backends(self, ["torch", "torchsde"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "torchsde"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "torchsde"])


================================================
FILE: 6DoF/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Stand-in class for `StableDiffusionXLImg2ImgPipeline`. Every entry point only
# forwards to `requires_backends` with the
# torch + transformers + invisible_watermark backend list.
# NOTE(review): the effect of `requires_backends` is not visible here --
# confirm in `..utils`.
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
    # Same backend names that every method below passes to `requires_backends`.
    _backends = ["torch", "transformers", "invisible_watermark"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never inspected.
        requires_backends(self, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])


# Stand-in class for `StableDiffusionXLPipeline`; identical shape to the class
# above, with the same three-backend list.
class StableDiffusionXLPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "invisible_watermark"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "invisible_watermark"])


================================================
FILE: 6DoF/diffusers/utils/dummy_torch_and_transformers_and_k_diffusion_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# Stand-in class for `StableDiffusionKDiffusionPipeline`. The constructor and
# both loader classmethods do nothing except call `requires_backends` with the
# torch + transformers + k_diffusion backend list.
# NOTE(review): presumably `requires_backends` raises when a listed backend is
# not installed -- confirm against its definition in `..utils`.
class StableDiffusionKDiffusionPipeline(metaclass=DummyObject):
    # Same backend names that every method below passes to `requires_backends`.
    _backends = ["torch", "transformers", "k_diffusion"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never inspected.
        requires_backends(self, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "k_diffusion"])


================================================
FILE: 6DoF/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


# All six classes below follow one template: `_backends` names
# torch + transformers + onnx, and `__init__`, `from_config` and
# `from_pretrained` each only forward to `requires_backends` with that list.
# NOTE(review): the effect of `requires_backends` is not visible here --
# confirm in `..utils`.


# Stand-in class for `OnnxStableDiffusionImg2ImgPipeline`.
class OnnxStableDiffusionImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but never inspected.
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Stand-in class for `OnnxStableDiffusionInpaintPipeline`.
class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Stand-in class for `OnnxStableDiffusionInpaintPipelineLegacy`.
class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Stand-in class for `OnnxStableDiffusionPipeline`.
class OnnxStableDiffusionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Stand-in class for `OnnxStableDiffusionUpscalePipeline`.
class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


# Stand-in class for `StableDiffusionOnnxPipeline`.
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers", "onnx"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers", "onnx"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers", "onnx"])


================================================
FILE: 6DoF/diffusers/utils/dummy_torch_and_transformers_objects.py
================================================
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AltDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class AudioLDMPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class 
IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFInpaintingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class IFSuperResolutionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ImageTextPipelineOutput(metaclass=DummyObject): 
_backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyPriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def 
__init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22ControlnetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Img2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22InpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): 
requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class KandinskyV22PriorPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class PaintByExamplePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class SemanticStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ShapEImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) 
@classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class ShapEPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) 
@classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionDiffEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod 
def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionInstructPix2PixPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionLatentUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionLDM3DPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionModelEditingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod 
def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPanoramaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionParadigmsPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPipelineSafe(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, 
**kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionSAGPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableDiffusionUpscalePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class StableUnCLIPPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class TextToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", 
"transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class TextToVideoZeroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UnCLIPImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UnCLIPPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UniDiffuserModel(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class UniDiffuserPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) class UniDiffuserTextDecoder(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): 
requires_backends(cls, ["torch", "transformers"]) class VideoToVideoSDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) class VQDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["torch", "transformers"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) ================================================ FILE: 6DoF/diffusers/utils/dummy_transformers_and_torch_and_note_seq_objects.py ================================================ # This file is autogenerated by the command `make fix-copies`, do not edit. from ..utils import DummyObject, requires_backends class SpectrogramDiffusionPipeline(metaclass=DummyObject): _backends = ["transformers", "torch", "note_seq"] def __init__(self, *args, **kwargs): requires_backends(self, ["transformers", "torch", "note_seq"]) @classmethod def from_config(cls, *args, **kwargs): requires_backends(cls, ["transformers", "torch", "note_seq"]) @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["transformers", "torch", "note_seq"]) ================================================ FILE: 6DoF/diffusers/utils/dynamic_modules_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
def get_diffusers_versions():
    """
    Fetch the `diffusers` releases listed by the PyPI JSON API.

    Performs a network request to `https://pypi.org/pypi/diffusers/json` on every call.

    Returns:
        `List[str]`: The release version strings, sorted in ascending version order.
    """
    url = "https://pypi.org/pypi/diffusers/json"
    # Use the response as a context manager so the underlying connection is closed deterministically instead of
    # being left open until garbage collection.
    with request.urlopen(url) as response:
        releases = json.loads(response.read())["releases"].keys()
    return sorted(releases, key=version.Version)
def get_relative_imports(module_file):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The de-duplicated names that follow the leading dot of each relative import, in order of first
        appearance (`import .xxx` matches first, then `from .xxx import yyy` matches).
    """
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # The patterns are raw strings: `\s`, `\.` and `\S` are regex escapes, and spelling them in a regular string
    # literal is an invalid escape sequence that Python warns about.
    # Imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from .xxx import yyy`
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Unique-ify while keeping a deterministic order
    return list(dict.fromkeys(relative_imports))


def get_relative_import_files(module_file):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The paths (`<folder of module_file>/<name>.py`) of every relatively imported file, each listed
        once, in the order they were discovered.
    """
    module_path = Path(module_file).parent
    files_to_check = [module_file]
    all_relative_imports = []
    # Paths already scheduled for inspection. Entries carry the `.py` suffix, exactly like the returned paths, so a
    # file is never scanned twice and circular relative imports terminate.
    seen = set()

    # Let's walk through all relative imports, one level at a time
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        files_to_check = []
        for imported_module in new_imports:
            candidate = f"{module_path / imported_module}.py"
            if candidate not in seen:
                seen.add(candidate)
                files_to_check.append(candidate)

        all_relative_imports.extend(files_to_check)

    return all_relative_imports
def get_class_in_module(class_name, module_path):
    """
    Import a module from the modules cache and pull a class out of it.

    Args:
        class_name (`str` or `None`): Attribute to fetch from the imported module. When `None`, the lookup is
            delegated to `find_pipeline_class`.
        module_path (`str`): Module location; every `os.path.sep` is turned into a dot before importing.

    Returns:
        The attribute named `class_name` of the imported module, or the result of `find_pipeline_class`.
    """
    dotted_name = ".".join(module_path.split(os.path.sep))
    loaded_module = importlib.import_module(dotted_name)
    return find_pipeline_class(loaded_module) if class_name is None else getattr(loaded_module, class_name)
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Downloads a module from a local folder or a distant repo and returns its path inside the dynamic modules cache.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co, namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a string without any `/`, which is treated as the name of a community pipeline and fetched from the
              `examples/community` folder of the diffusers GitHub repository.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*):
            The specific version to use. For Hub repos it can be a branch name, a tag name, or a commit id. For
            community pipelines it must be a released diffusers version or `"main"`; when left to `None`, the installed
            diffusers version is used if it is a published release, and `"main"` otherwise.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.

    You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli login`) and want to use private
    or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).

    Returns:
        `str`: The path to the module inside the cache.

    Raises:
        `ValueError`: If a community pipeline is requested with a `revision` that is neither a released diffusers
            version nor `"main"`.
        `EnvironmentError`: Re-raised (after logging) when the module file cannot be downloaded.
        `ImportError`: Raised by `check_imports` when the module needs packages that are not installed.
    """
    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)

    module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)

    if os.path.isfile(module_file_or_url):
        resolved_module_file = module_file_or_url
        submodule = "local"
    elif pretrained_model_name_or_path.count("/") == 0:
        available_versions = get_diffusers_versions()
        # cut ".dev0"
        latest_version = "v" + ".".join(__version__.split(".")[:3])

        # retrieve github version that matches
        if revision is None:
            revision = latest_version if latest_version[1:] in available_versions else "main"
            logger.info(f"Defaulting to latest_version: {revision}.")
        elif revision in available_versions:
            revision = f"v{revision}"
        elif revision != "main":
            # `"main"` is used as is; anything else is not a known GitHub revision.
            raise ValueError(
                f"`custom_revision`: {revision} does not exist. Please make sure to choose one of"
                f" {', '.join(available_versions + ['main'])}."
            )

        # community pipeline on GitHub
        github_url = COMMUNITY_PIPELINES_URL.format(revision=revision, pipeline=pretrained_model_name_or_path)
        try:
            resolved_module_file = cached_download(
                github_url,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=False,
            )
            submodule = "git"
            module_file = pretrained_model_name_or_path + ".py"
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise
    else:
        try:
            # Load from URL or cache if already cached
            resolved_module_file = hf_hub_download(
                pretrained_model_name_or_path,
                module_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
            )
            submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
        except EnvironmentError:
            logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
            raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    if submodule == "local" or submodule == "git":
        # We always copy local files (we could hash the file to see if there was a change, and give them the name of
        # that hash, to only copy when there is a modification but it seems overkill for now).
        # The only reason we do the copy is to avoid putting too many folders in sys.path.
        shutil.copy(resolved_module_file, submodule_path / module_file)
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
    else:
        # Get the commit hash
        # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
        if isinstance(use_auth_token, str):
            token = use_auth_token
        elif use_auth_token is True:
            token = HfFolder.get_token()
        else:
            token = None

        commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
        # Make sure we also have every file with relative imports
        for module_needed in modules_needed:
            # `modules_needed` holds bare module names while the cached files carry the `.py` suffix, so the suffix
            # must be part of the existence check for an already cached dependency to be skipped.
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                )
    return os.path.join(full_submodule, module_file)
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should therefore only be called on trusted repos. Args: pretrained_model_name_or_path (`str` or `os.PathLike`): This can be either: - a string, the *model id* of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - a path to a *directory* containing a configuration file saved using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. module_file (`str`): The name of the module file containing the class to look for. class_name (`str`): The name of the class to import in the module. cache_dir (`str` or `os.PathLike`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force to (re-)download the configuration files and override the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists. proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. use_auth_token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. local_files_only (`bool`, *optional*, defaults to `False`): If `True`, will only try to load the tokenizer configuration from local files. You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models). Returns: `type`: The class, dynamically imported from the module. Examples: ```python # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this # module. cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel") ```""" # And lastly we get the class inside our newly created module final_module = get_cached_module_file( pretrained_model_name_or_path, module_file, cache_dir=cache_dir, force_download=force_download, resume_download=resume_download, proxies=proxies, use_auth_token=use_auth_token, revision=revision, local_files_only=local_files_only, ) return get_class_in_module(class_name, final_module.replace(".py", "")) ================================================ FILE: 6DoF/diffusers/utils/hub_utils.py ================================================ # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import re import sys import traceback import warnings from pathlib import Path from typing import Dict, Optional, Union from uuid import uuid4 from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami from huggingface_hub.file_download import REGEX_COMMIT_HASH from huggingface_hub.utils import ( EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError, is_jinja_available, ) from packaging import version from requests import HTTPError from .. import __version__ from .constants import ( DEPRECATED_REVISION_ARGS, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, ) from .import_utils import ( ENV_VARS_TRUE_VALUES, _flax_version, _jax_version, _onnxruntime_version, _torch_version, is_flax_available, is_onnx_available, is_torch_available, ) from .logging import get_logger logger = get_logger(__name__) MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md" SESSION_ID = uuid4().hex HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/" def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: """ Formats a user-agent string with basic info about a request. 
def create_model_card(args, model_name):
    """
    Render the model card template and save it as `README.md` inside `args.output_dir`.

    Nothing is written when `args` has a `local_rank` attribute whose value is neither `-1` nor `0`.

    Args:
        args: Namespace-like object holding the training arguments. `learning_rate`, `train_batch_size`,
            `eval_batch_size`, `mixed_precision` and `output_dir` are read unconditionally; `hub_token`,
            `dataset_name` and the remaining optimizer / scheduler / EMA attributes are optional and default to
            `None` when missing.
        model_name (`str`): Name of the model, also used to build the full repository name.

    Raises:
        `ValueError`: If Jinja is not available.
    """
    if not is_jinja_available():
        raise ValueError(
            "Modelcard rendering is based on Jinja templates."
            " Please make sure to have `jinja` installed before using `create_model_card`."
            " To install it, please run `pip install Jinja2`."
        )

    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    hub_token = getattr(args, "hub_token", None)
    repo_name = get_full_repo_name(model_name, token=hub_token)

    # `dataset_name` is optional: read it once with a default so the card metadata and the template variable agree,
    # instead of raising `AttributeError` in one place while being guarded in the other.
    dataset_name = getattr(args, "dataset_name", None)

    model_card = ModelCard.from_template(
        card_data=ModelCardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=dataset_name,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=getattr(args, "gradient_accumulation_steps", None),
        adam_beta1=getattr(args, "adam_beta1", None),
        adam_beta2=getattr(args, "adam_beta2", None),
        adam_weight_decay=getattr(args, "adam_weight_decay", None),
        adam_epsilon=getattr(args, "adam_epsilon", None),
        lr_scheduler=getattr(args, "lr_scheduler", None),
        lr_warmup_steps=getattr(args, "lr_warmup_steps", None),
        ema_inv_gamma=getattr(args, "ema_inv_gamma", None),
        ema_power=getattr(args, "ema_power", None),
        ema_max_decay=getattr(args, "ema_max_decay", None),
        mixed_precision=args.mixed_precision,
    )

    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
    """
    Relocate every blob file of an old cache tree into a new cache tree.

    Each regular file matching `**/blobs/*` under the old cache is moved to the same relative location under the new
    cache, and a symlink pointing at the moved file is left in its place. Entries that are already symlinks are
    skipped.

    Args:
        old_cache_dir (`str`, *optional*): Root of the cache to migrate; `old_diffusers_cache` when `None`.
        new_cache_dir (`str`, *optional*): Root of the destination cache; `DIFFUSERS_CACHE` when `None`.
    """
    source_root = Path(old_diffusers_cache if old_cache_dir is None else old_cache_dir).expanduser()
    target_root = Path(DIFFUSERS_CACHE if new_cache_dir is None else new_cache_dir).expanduser()

    for blob in source_root.glob("**/blobs/*"):
        # Only plain files are migrated: directories and symlinks are left untouched.
        if not blob.is_file() or blob.is_symlink():
            continue

        destination = target_root / blob.relative_to(source_root)
        destination.parent.mkdir(parents=True, exist_ok=True)
        os.replace(blob, destination)
        try:
            os.symlink(destination, blob)
        except OSError:
            logger.warning(
                "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
            )
    # From here on the old cache only holds symlinks into the new cache, so it remains usable.
# One-time cache migration, executed when this module is imported. The version file records whether the
# migration has already been performed for this cache directory.
cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")

# A missing or unparsable version file is treated as version 0 (never migrated).
cache_version = 0
if os.path.isfile(cache_version_file):
    with open(cache_version_file) as f:
        try:
            cache_version = int(f.read())
        except ValueError:
            cache_version = 0

if cache_version < 1:
    old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and bool(os.listdir(old_diffusers_cache))
    if old_cache_is_not_empty:
        logger.warning(
            "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
            "existing cached models. This is a one-time operation, you can interrupt it or run it "
            "later by calling `diffusers.utils.hub_utils.move_cache()`."
        )
        try:
            move_cache()
        except Exception as e:
            # A failed migration must not prevent the library from being imported: report it and go on.
            trace = "\n".join(traceback.format_tb(e.__traceback__))
            logger.error(
                f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
                "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
                "message and we will do our best to help."
            )

    # Record the new cache version whether or not anything had to be moved.
    try:
        os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
            "the directory exists and can be written to."
        )
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """
    Insert `variant` in front of the last dot-separated piece of `weights_name`.

    For example `"model.bin"` with variant `"fp16"` becomes `"model.fp16.bin"`. The name is returned unchanged
    when `variant` is `None`.
    """
    if variant is not None:
        splits = weights_name.split(".")
        # Keep the last piece (the extension) at the end and slot the variant in just before it.
        splits = splits[:-1] + [variant] + splits[-1:]
        weights_name = ".".join(splits)

    return weights_name


def _get_model_file(
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
    commit_hash=None,
):
    """
    Resolve the path of the weights file `weights_name` for `pretrained_model_name_or_path`.

    Resolution order:

    1. If the argument is an existing file, it is returned as is.
    2. If it is an existing directory, `weights_name` is looked up directly inside it and then inside
       `subfolder` (when given).
    3. Otherwise the file is fetched with `hf_hub_download`; all remaining keyword arguments are forwarded to
       that call.

    Returns:
        `str`: Path to the weights file.

    Raises:
        `EnvironmentError`: If the file is missing from a local directory, or if the download fails. Download
        failures are re-raised as `EnvironmentError` with a message describing the likely cause.
    """
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    elif os.path.isdir(pretrained_model_name_or_path):
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # Load from a PyTorch checkpoint
            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
            return model_file
        elif subfolder is not None and os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
        ):
            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
            return model_file
        else:
            raise EnvironmentError(
                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
            )
    else:
        # 1. First check if deprecated way of loading from branches is used
        if (
            revision in DEPRECATED_REVISION_ARGS
            and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
            and version.parse(version.parse(__version__).base_version) >= version.parse("0.20.0")
        ):
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=_add_variant(weights_name, revision),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision or commit_hash,
                )
                warnings.warn(
                    f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
                    FutureWarning,
                )
                return model_file
            except Exception:
                # Best effort only: any failure falls through to the regular download below. `Exception` is
                # used rather than a bare `except` so that KeyboardInterrupt and SystemExit still propagate.
                warnings.warn(
                    f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
                    FutureWarning,
                )
        try:
            # 2. Load model file as usual
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=weights_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision or commit_hash,
            )
            return model_file

        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                "login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                "this model name. Check the model page at "
                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
            )
        except ValueError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a file named {weights_name} or"
                " \nCheckout your internet connection or see how to run the library in"
                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a file named {weights_name}"
            )
""" import importlib.util import operator as op import os import sys from collections import OrderedDict from typing import Union from huggingface_hub.utils import is_jinja_available # noqa: F401 from packaging import version from packaging.version import Version, parse from . import logging # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata logger = logging.get_logger(__name__) # pylint: disable=invalid-name ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None if _torch_available: try: _torch_version = importlib_metadata.version("torch") logger.info(f"PyTorch version {_torch_version} available.") except importlib_metadata.PackageNotFoundError: _torch_available = False else: logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False _tf_version = "N/A" if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: _tf_available = importlib.util.find_spec("tensorflow") is not None if _tf_available: candidates = ( "tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly", "tf-nightly-cpu", "tf-nightly-gpu", "intel-tensorflow", "intel-tensorflow-avx512", "tensorflow-rocm", "tensorflow-macos", "tensorflow-aarch64", ) _tf_version = None # For the metadata, we have to look for both tensorflow and tensorflow-cpu for 
pkg in candidates: try: _tf_version = importlib_metadata.version(pkg) break except importlib_metadata.PackageNotFoundError: pass _tf_available = _tf_version is not None if _tf_available: if version.parse(_tf_version) < version.parse("2"): logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.") _tf_available = False else: logger.info(f"TensorFlow version {_tf_version} available.") else: logger.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False _jax_version = "N/A" _flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None if _flax_available: try: _jax_version = importlib_metadata.version("jax") _flax_version = importlib_metadata.version("flax") logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") except importlib_metadata.PackageNotFoundError: _flax_available = False else: _flax_available = False if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: _safetensors_available = importlib.util.find_spec("safetensors") is not None if _safetensors_available: try: _safetensors_version = importlib_metadata.version("safetensors") logger.info(f"Safetensors version {_safetensors_version} available.") except importlib_metadata.PackageNotFoundError: _safetensors_available = False else: logger.info("Disabling Safetensors because USE_TF is set") _safetensors_available = False _transformers_available = importlib.util.find_spec("transformers") is not None try: _transformers_version = importlib_metadata.version("transformers") logger.debug(f"Successfully imported transformers version {_transformers_version}") except importlib_metadata.PackageNotFoundError: _transformers_available = False _inflect_available = importlib.util.find_spec("inflect") is not None try: _inflect_version = importlib_metadata.version("inflect") logger.debug(f"Successfully imported inflect version 
{_inflect_version}") except importlib_metadata.PackageNotFoundError: _inflect_available = False _unidecode_available = importlib.util.find_spec("unidecode") is not None try: _unidecode_version = importlib_metadata.version("unidecode") logger.debug(f"Successfully imported unidecode version {_unidecode_version}") except importlib_metadata.PackageNotFoundError: _unidecode_available = False _onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: candidates = ( "onnxruntime", "onnxruntime-gpu", "ort_nightly_gpu", "onnxruntime-directml", "onnxruntime-openvino", "ort_nightly_directml", "onnxruntime-rocm", "onnxruntime-training", ) _onnxruntime_version = None # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu for pkg in candidates: try: _onnxruntime_version = importlib_metadata.version(pkg) break except importlib_metadata.PackageNotFoundError: pass _onnx_available = _onnxruntime_version is not None if _onnx_available: logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. 
# _opencv_available = importlib.util.find_spec("opencv-python") is not None try: candidates = ( "opencv-python", "opencv-contrib-python", "opencv-python-headless", "opencv-contrib-python-headless", ) _opencv_version = None for pkg in candidates: try: _opencv_version = importlib_metadata.version(pkg) break except importlib_metadata.PackageNotFoundError: pass _opencv_available = _opencv_version is not None if _opencv_available: logger.debug(f"Successfully imported cv2 version {_opencv_version}") except importlib_metadata.PackageNotFoundError: _opencv_available = False _scipy_available = importlib.util.find_spec("scipy") is not None try: _scipy_version = importlib_metadata.version("scipy") logger.debug(f"Successfully imported scipy version {_scipy_version}") except importlib_metadata.PackageNotFoundError: _scipy_available = False _librosa_available = importlib.util.find_spec("librosa") is not None try: _librosa_version = importlib_metadata.version("librosa") logger.debug(f"Successfully imported librosa version {_librosa_version}") except importlib_metadata.PackageNotFoundError: _librosa_available = False _accelerate_available = importlib.util.find_spec("accelerate") is not None try: _accelerate_version = importlib_metadata.version("accelerate") logger.debug(f"Successfully imported accelerate version {_accelerate_version}") except importlib_metadata.PackageNotFoundError: _accelerate_available = False _xformers_available = importlib.util.find_spec("xformers") is not None try: _xformers_version = importlib_metadata.version("xformers") if _torch_available: import torch if version.Version(torch.__version__) < version.Version("1.12"): raise ValueError("PyTorch should be >= 1.12") logger.debug(f"Successfully imported xformers version {_xformers_version}") except importlib_metadata.PackageNotFoundError: _xformers_available = False _k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None try: _k_diffusion_version = importlib_metadata.version("k_diffusion") 
# Detection of optional third-party integrations. Each block follows the same two-step pattern:
# `find_spec` says whether the module can be imported, and the distribution metadata supplies the version.
# If the metadata is missing, the package is treated as unavailable even when a module of that name exists.
_note_seq_available = importlib.util.find_spec("note_seq") is not None
try:
    _note_seq_version = importlib_metadata.version("note_seq")
    logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
except importlib_metadata.PackageNotFoundError:
    _note_seq_available = False


_wandb_available = importlib.util.find_spec("wandb") is not None
try:
    _wandb_version = importlib_metadata.version("wandb")
    logger.debug(f"Successfully imported wandb version {_wandb_version }")
except importlib_metadata.PackageNotFoundError:
    _wandb_available = False


_omegaconf_available = importlib.util.find_spec("omegaconf") is not None
try:
    _omegaconf_version = importlib_metadata.version("omegaconf")
    logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
except importlib_metadata.PackageNotFoundError:
    _omegaconf_available = False


# `is not None` makes this flag a bool like every other `_*_available` flag (it used to hold the ModuleSpec).
_tensorboard_available = importlib.util.find_spec("tensorboard") is not None
try:
    _tensorboard_version = importlib_metadata.version("tensorboard")
    logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
except importlib_metadata.PackageNotFoundError:
    _tensorboard_available = False


# Same normalisation to bool as for tensorboard above.
_compel_available = importlib.util.find_spec("compel") is not None
try:
    _compel_version = importlib_metadata.version("compel")
    logger.debug(f"Successfully imported compel version {_compel_version}")
except importlib_metadata.PackageNotFoundError:
    _compel_available = False


_ftfy_available = importlib.util.find_spec("ftfy") is not None
try:
    _ftfy_version = importlib_metadata.version("ftfy")
    logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
except importlib_metadata.PackageNotFoundError:
    _ftfy_available = False


_bs4_available = importlib.util.find_spec("bs4") is not None
try:
    # importlib metadata under different name
    _bs4_version = importlib_metadata.version("beautifulsoup4")
    logger.debug(f"Successfully imported bs4 version {_bs4_version}")
except importlib_metadata.PackageNotFoundError:
    _bs4_available = False


_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
    _torchsde_version = importlib_metadata.version("torchsde")
    logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
    _torchsde_available = False


# The module is imported as `imwatermark` but distributed as `invisible-watermark`.
_invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None
try:
    _invisible_watermark_version = importlib_metadata.version("invisible-watermark")
    logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}")
except importlib_metadata.PackageNotFoundError:
    _invisible_watermark_available = False
is_invisible_watermark_available(): return _invisible_watermark_available # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the installation page: https://github.com/google/flax and follow the ones that match your environment. """ # docstyle-ignore INFLECT_IMPORT_ERROR = """ {0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install inflect` """ # docstyle-ignore PYTORCH_IMPORT_ERROR = """ {0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ # docstyle-ignore ONNX_IMPORT_ERROR = """ {0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip install onnxruntime` """ # docstyle-ignore OPENCV_IMPORT_ERROR = """ {0} requires the OpenCV library but it was not found in your environment. You can install it with pip: `pip install opencv-python` """ # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install scipy` """ # docstyle-ignore LIBROSA_IMPORT_ERROR = """ {0} requires the librosa library but it was not found in your environment. Checkout the instructions on the installation page: https://librosa.org/doc/latest/install.html and follow the ones that match your environment. """ # docstyle-ignore TRANSFORMERS_IMPORT_ERROR = """ {0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip install transformers` """ # docstyle-ignore UNIDECODE_IMPORT_ERROR = """ {0} requires the unidecode library but it was not found in your environment. 
You can install it with pip: `pip install Unidecode` """ # docstyle-ignore K_DIFFUSION_IMPORT_ERROR = """ {0} requires the k-diffusion library but it was not found in your environment. You can install it with pip: `pip install k-diffusion` """ # docstyle-ignore NOTE_SEQ_IMPORT_ERROR = """ {0} requires the note-seq library but it was not found in your environment. You can install it with pip: `pip install note-seq` """ # docstyle-ignore WANDB_IMPORT_ERROR = """ {0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip install wandb` """ # docstyle-ignore OMEGACONF_IMPORT_ERROR = """ {0} requires the omegaconf library but it was not found in your environment. You can install it with pip: `pip install omegaconf` """ # docstyle-ignore TENSORBOARD_IMPORT_ERROR = """ {0} requires the tensorboard library but it was not found in your environment. You can install it with pip: `pip install tensorboard` """ # docstyle-ignore COMPEL_IMPORT_ERROR = """ {0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel` """ # docstyle-ignore BS4_IMPORT_ERROR = """ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip: `pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. """ # docstyle-ignore FTFY_IMPORT_ERROR = """ {0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones that match your environment. Please note that you may need to restart your runtime after installation. """ # docstyle-ignore TORCHSDE_IMPORT_ERROR = """ {0} requires the torchsde library but it was not found in your environment. 
def requires_backends(obj, backends):
    """
    Raise an `ImportError` if any of the given backends is not available.

    Args:
        obj: Object, class or function the check is performed for. Only its name is used, in the messages.
        backends (`str` or `list`/`tuple` of `str`): Keys of `BACKENDS_MAPPING` that must be available.

    Raises:
        `ImportError`: If a backend is missing (the messages of all missing backends are concatenated), or if
        `obj` is one of the pipelines listed below and the installed `transformers` is too old for it.
    """
    if not isinstance(backends, (list, tuple)):
        backends = [backends]

    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    checks = (BACKENDS_MAPPING[backend] for backend in backends)
    failed = [msg.format(name) for available, msg in checks if not available()]
    if failed:
        raise ImportError("".join(failed))

    if name in [
        "VersatileDiffusionTextToImagePipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionDualGuidedPipeline",
        "StableDiffusionImageVariationPipeline",
        "UnCLIPPipeline",
    ] and is_transformers_version("<", "4.25.0"):
        raise ImportError(
            f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )

    if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
        "<", "4.26.0"
    ):
        raise ImportError(
            f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
            " --upgrade transformers \n```"
        )


class DummyObject(type):
    """
    Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
    `requires_backend` each time a user tries to access any method of that class.
    """

    def __getattr__(cls, key):
        # `__getattr__` only runs after normal lookup has failed, so a private name reaching this point is
        # simply missing. Raise a regular AttributeError naming it: `type` defines no `__getattr__` to
        # delegate to, and `hasattr()` and similar probes rely on AttributeError being raised here.
        if key.startswith("_"):
            raise AttributeError(f"type object {cls.__name__!r} has no attribute {key!r}")
        requires_backends(cls, cls._backends)
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
def is_torch_version(operation: str, version: str):
    """
    Compares the current PyTorch version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A string version of PyTorch

    Returns:
        `False` when PyTorch is not available, otherwise the result of the comparison.
    """
    # Without PyTorch `_torch_version` is the placeholder "N/A", which is not a parsable version string.
    # Return False in that case, as the sibling helpers below do for their own libraries.
    if not _torch_available:
        return False
    return compare_versions(parse(_torch_version), operation, version)


def is_transformers_version(operation: str, version: str):
    """
    Compares the current Transformers version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `False` when Transformers is not available, otherwise the result of the comparison.
    """
    if not _transformers_available:
        return False
    return compare_versions(parse(_transformers_version), operation, version)


def is_accelerate_version(operation: str, version: str):
    """
    Compares the current Accelerate version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `False` when Accelerate is not available, otherwise the result of the comparison.
    """
    if not _accelerate_available:
        return False
    return compare_versions(parse(_accelerate_version), operation, version)


def is_k_diffusion_version(operation: str, version: str):
    """
    Compares the current k-diffusion version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `False` when k-diffusion is not available, otherwise the result of the comparison.
    """
    if not _k_diffusion_available:
        return False
    return compare_versions(parse(_k_diffusion_version), operation, version)
operation (`str`): A string representation of an operator, such as `">"` or `"<="` version (`str`): A version string """ if not _k_diffusion_available: return False return compare_versions(parse(_k_diffusion_version), operation, version) class OptionalDependencyNotAvailable(BaseException): """An error indicating that an optional dependency of Diffusers was not found in the environment.""" ================================================ FILE: 6DoF/diffusers/utils/logging.py ================================================ # coding=utf-8 # Copyright 2023 Optuna, Hugging Face # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Logging utilities.""" import logging import os import sys import threading from logging import ( CRITICAL, # NOQA DEBUG, # NOQA ERROR, # NOQA FATAL, # NOQA INFO, # NOQA NOTSET, # NOQA WARN, # NOQA WARNING, # NOQA ) from typing import Optional from tqdm import auto as tqdm_lib _lock = threading.Lock() _default_handler: Optional[logging.Handler] = None log_levels = { "debug": logging.DEBUG, "info": logging.INFO, "warning": logging.WARNING, "error": logging.ERROR, "critical": logging.CRITICAL, } _default_log_level = logging.WARNING _tqdm_active = True def _get_default_logging_level(): """ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. 
If it is not - fall back to `_default_log_level` """ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None) if env_level_str: if env_level_str in log_levels: return log_levels[env_level_str] else: logging.getLogger().warning( f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, " f"has to be one of: { ', '.join(log_levels.keys()) }" ) return _default_log_level def _get_library_name() -> str: return __name__.split(".")[0] def _get_library_root_logger() -> logging.Logger: return logging.getLogger(_get_library_name()) def _configure_library_root_logger() -> None: global _default_handler with _lock: if _default_handler: # This library has already configured the library root logger. return _default_handler = logging.StreamHandler() # Set sys.stderr as stream. _default_handler.flush = sys.stderr.flush # Apply our default configuration to the library root logger. library_root_logger = _get_library_root_logger() library_root_logger.addHandler(_default_handler) library_root_logger.setLevel(_get_default_logging_level()) library_root_logger.propagate = False def _reset_library_root_logger() -> None: global _default_handler with _lock: if not _default_handler: return library_root_logger = _get_library_root_logger() library_root_logger.removeHandler(_default_handler) library_root_logger.setLevel(logging.NOTSET) _default_handler = None def get_log_levels_dict(): return log_levels def get_logger(name: Optional[str] = None) -> logging.Logger: """ Return a logger with the specified name. This function is not supposed to be directly accessed unless you are writing a custom diffusers module. """ if name is None: name = _get_library_name() _configure_library_root_logger() return logging.getLogger(name) def get_verbosity() -> int: """ Return the current level for the 🤗 Diffusers' root logger as an `int`. 
def remove_handler(handler: logging.Handler) -> None:
    """
    Remove the given handler from the HuggingFace Diffusers' root logger.

    Args:
        handler (`logging.Handler`):
            The handler to detach. Must not be `None`. Passing a handler that is not currently attached is a
            no-op.
    """

    _configure_library_root_logger()

    assert handler is not None
    # The previous version additionally asserted `handler not in <root logger>.handlers`, which failed exactly
    # when the handler was attached, i.e. whenever there was something to remove. `Logger.removeHandler`
    # already ignores handlers that are not attached, so no membership check is needed.
    _get_library_root_logger().removeHandler(handler)
class EmptyTqdm:
    """Stand-in for a tqdm bar that silently accepts every call and shows nothing."""

    def __init__(self, *args, **kwargs):  # pylint: disable=unused-argument
        # Only the first positional argument (the iterable, if any) is kept.
        if args:
            self._iterator = args[0]
        else:
            self._iterator = None

    def __iter__(self):
        return iter(self._iterator)

    def __getattr__(self, _):
        """Hand back a do-nothing callable for any attribute that is not defined."""

        def _noop(*args, **kwargs):  # pylint: disable=unused-argument
            return None

        return _noop

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        return None
dataset_name }}` dataset. ## Intended uses & limitations #### How to use ```python # TODO: add an example code snippet for running this diffusion pipeline ``` #### Limitations and bias [TODO: provide examples of latent issues and potential remediations] ## Training data [TODO: describe the data used to train the model] ### Training hyperparameters The following hyperparameters were used during training: - learning_rate: {{ learning_rate }} - train_batch_size: {{ train_batch_size }} - eval_batch_size: {{ eval_batch_size }} - gradient_accumulation_steps: {{ gradient_accumulation_steps }} - optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }} - lr_scheduler: {{ lr_scheduler }} - lr_warmup_steps: {{ lr_warmup_steps }} - ema_inv_gamma: {{ ema_inv_gamma }} - ema_power: {{ ema_power }} - ema_max_decay: {{ ema_max_decay }} - mixed_precision: {{ mixed_precision }} ### Training results 📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars) ================================================ FILE: 6DoF/diffusers/utils/outputs.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like
    a tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    Python dictionary.

    <Tip warning={true}>

    You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
    first.

    </Tip>
    """

    def __post_init__(self):
        declared = fields(self)

        # Safety and consistency checks
        if len(declared) == 0:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        head, *tail = declared
        head_value = getattr(self, head.name)
        only_head_is_set = all(getattr(self, entry.name) is None for entry in tail)

        if only_head_is_set and isinstance(head_value, dict):
            # A single dict passed as the first field is unpacked into the mapping itself.
            for key, value in head_value.items():
                self[key] = value
            return

        # Otherwise mirror every non-None field into the mapping, in declaration order.
        for entry in declared:
            current = getattr(self, entry.name)
            if current is not None:
                self[entry.name] = current

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        # Integers and slices index the tuple view; strings index the mapping.
        if not isinstance(k, str):
            return self.to_tuple()[k]
        return dict(self.items())[k]

    def __setattr__(self, name, value):
        if value is not None and name in self.keys():
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        return tuple(self[key] for key in self.keys())
images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [Image.fromarray(image) for image in images] return pil_images ================================================ FILE: 6DoF/diffusers/utils/testing_utils.py ================================================ import inspect import logging import multiprocessing import os import random import re import tempfile import unittest import urllib.parse from distutils.util import strtobool from io import BytesIO, StringIO from pathlib import Path from typing import List, Optional, Union import numpy as np import PIL.Image import PIL.ImageOps import requests from packaging import version from .import_utils import ( BACKENDS_MAPPING, is_compel_available, is_flax_available, is_note_seq_available, is_onnx_available, is_opencv_available, is_torch_available, is_torch_version, is_torchsde_available, ) from .logging import get_logger global_rng = random.Random() logger = get_logger(__name__) if is_torch_available(): import torch if "DIFFUSERS_TEST_DEVICE" in os.environ: torch_device = os.environ["DIFFUSERS_TEST_DEVICE"] available_backends = ["cuda", "cpu", "mps"] if torch_device not in available_backends: raise ValueError( f"unknown torch backend for diffusers tests: {torch_device}. Available backends are:" f" {available_backends}" ) logger.info(f"torch_device overrode to {torch_device}") else: torch_device = "cuda" if torch.cuda.is_available() else "cpu" is_torch_higher_equal_than_1_12 = version.parse( version.parse(torch.__version__).base_version ) >= version.parse("1.12") if is_torch_higher_equal_than_1_12: # Some builds of torch 1.12 don't have the mps backend registered. 
def get_tests_dir(append_path=None):
    """
    Args:
        append_path: optional path to append to the tests dir path

    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
        joined after the `tests` dir the former is provided.

    Raises:
        ValueError: if neither the caller's directory nor any of its ancestors has a path ending in `tests`.
    """
    # this function caller's __file__
    caller__file__ = inspect.stack()[1][1]
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))

    while not tests_dir.endswith("tests"):
        parent_dir = os.path.dirname(tests_dir)
        if parent_dir == tests_dir:
            # `dirname` of the filesystem root is the root itself; without this guard the loop never ends
            # when the caller does not live below a `tests` directory.
            raise ValueError(f"Could not find a `tests` directory above {caller__file__}.")
        tests_dir = parent_dir

    if append_path:
        return os.path.join(tests_dir, append_path)
    return tests_dir
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor of the given `shape`, with values drawn from `rng` and multiplied by `scale`."""
    # Fall back to the module-wide generator when the caller does not supply one.
    source = global_rng if rng is None else rng

    element_count = 1
    for extent in shape:
        element_count *= extent

    # One draw per element, in flat order, so a seeded `rng` yields a reproducible tensor.
    samples = [source.random() * scale for _ in range(element_count)]

    return torch.tensor(data=samples, dtype=torch.float).view(shape).contiguous()
def skip_mps(test_case):
    """Decorator that skips the wrapped test whenever `torch_device` is 'mps'."""
    running_on_mps = torch_device == "mps"
    return unittest.skipIf(running_on_mps, "test requires non 'mps' device")(test_case)
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
    """
    Resolve `arry` to a numpy array.

    Args:
        arry (`str` or `np.ndarray`):
            An `http://`/`https://` URL, a path to an existing file, or an array (returned unchanged).
        local_path (`str`, *optional*):
            When given together with a string `arry`, nothing is loaded: the function returns `local_path` joined
            with the 5th-from-last, 2nd-from-last and last `/`-separated components of `arry`.

    Raises:
        ValueError: if `arry` is neither a string nor an array, or is a string that is neither a URL nor an
            existing file.
    """
    if isinstance(arry, np.ndarray):
        return arry

    if not isinstance(arry, str):
        raise ValueError(
            "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
            " ndarray."
        )

    if local_path is not None:
        # local_path can be passed to correct images of tests
        # NOTE(review): this branch returns a path string rather than an ndarray - confirm that is intended.
        pieces = arry.split("/")
        return os.path.join(local_path, "/".join([pieces[-5], pieces[-2], pieces[-1]]))

    if arry.startswith("http://") or arry.startswith("https://"):
        response = requests.get(arry)
        response.raise_for_status()
        return np.load(BytesIO(response.content))

    if os.path.isfile(arry):
        return np.load(arry)

    raise ValueError(
        f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
    )
def load_hf_numpy(path) -> np.ndarray:
    """
    Load a numpy array through `load_numpy`, treating non-URL paths as relative to the diffusers testing dataset.

    Args:
        path (`str`):
            Either a full `http://`/`https://` URL, which is used as is, or a path that is URL-quoted and
            appended to `https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main`.

    Returns:
        Whatever `load_numpy` returns for the resolved URL.
    """
    # The previous condition was `not path.startswith("http://") or path.startswith("https://")`. Because
    # `not` binds tighter than `or`, every `https://` URL satisfied it and was wrongly prefixed with the base
    # URL. A tuple argument to `startswith` tests both schemes in one call.
    if not path.startswith(("http://", "https://")):
        path = os.path.join(
            "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
        )

    return load_numpy(path)
def pytest_terminal_summary_main(tr, id):
    """
    Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
    directory. The report files are prefixed with the test suite name.

    This function emulates --duration and -rA pytest arguments.

    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
    there.

    Args:
    - tr: `terminalreporter` passed from `conftest.py`
    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.

    NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
    changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
    plugins and interfere.

    The reporter's writer, `reportchars` and the `tbstyle` option are temporarily replaced while the reports are
    written and are put back afterwards, including when one of the summary calls raises.
    """
    from _pytest.config import create_terminal_writer

    if not len(id):
        id = "tests"

    config = tr.config
    orig_writer = config.get_terminal_writer()
    orig_tbstyle = config.option.tbstyle
    orig_reportchars = tr.reportchars

    report_dir = "reports"
    Path(report_dir).mkdir(parents=True, exist_ok=True)
    report_files = {
        k: f"{report_dir}/{id}_{k}.txt"
        for k in [
            "durations",
            "errors",
            "failures_long",
            "failures_short",
            "failures_line",
            "passes",
            "stats",
            "summary_short",
            "warnings",
        ]
    }

    # custom durations report
    # note: there is no need to call pytest --durations=XX to get this separate report
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
    dlist = [rep for replist in tr.stats.values() for rep in replist if hasattr(rep, "duration")]
    if dlist:
        dlist.sort(key=lambda x: x.duration, reverse=True)
        with open(report_files["durations"], "w") as f:
            durations_min = 0.05  # sec
            f.write("slowest durations\n")
            for i, rep in enumerate(dlist):
                if rep.duration < durations_min:
                    f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
                    break
                f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

    def summary_failures_short(tr):
        # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
        reports = tr.getreports("failed")
        if not reports:
            return
        tr.write_sep("=", "FAILURES SHORT STACK")
        for rep in reports:
            msg = tr._getfailureheadline(rep)
            tr.write_sep("_", msg, red=True, bold=True)
            # chop off the optional leading extra frames, leaving only the last one
            # (count/flags are passed by keyword: passing them positionally is deprecated since Python 3.13)
            longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, count=0, flags=re.M | re.S)
            tr._tw.line(longrepr)
            # note: not printing out any rep.sections to keep the report short

    # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
    # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
    # pytest-instafail does that)

    try:
        # report failures with line/short/long styles
        config.option.tbstyle = "auto"  # full tb
        with open(report_files["failures_long"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_failures()

        # config.option.tbstyle = "short" # short tb
        with open(report_files["failures_short"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            summary_failures_short(tr)

        config.option.tbstyle = "line"  # one line per error
        with open(report_files["failures_line"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_failures()

        with open(report_files["errors"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_errors()

        with open(report_files["warnings"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_warnings()  # normal warnings
            tr.summary_warnings()  # final warnings

        tr.reportchars = "wPpsxXEf"  # emulate -rA (used in summary_passes() and short_test_summary())
        with open(report_files["passes"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_passes()

        with open(report_files["summary_short"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.short_test_summary()

        with open(report_files["stats"], "w") as f:
            tr._tw = create_terminal_writer(config, f)
            tr.summary_stats()
    finally:
        # restore - also on failure, so the reporter is never left writing to an already closed report file
        tr._tw = orig_writer
        tr.reportchars = orig_reportchars
        config.option.tbstyle = orig_tbstyle
class CaptureLogger:
    """
    Context manager to capture `logging` streams

    Args:
        logger: `logging` logger object

    Returns:
        The captured output is available via `self.out`

    Example:

    ```python
    >>> from diffusers.utils import logging
    >>> from diffusers.utils.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out == msg + "\\n"
    ```
    """

    def __init__(self, logger):
        buffer = StringIO()
        self.out = ""
        self.logger = logger
        self.io = buffer
        # Everything the logger emits while the handler is attached lands in the in-memory buffer.
        self.sh = logging.StreamHandler(buffer)

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)
        # Snapshot the captured text once the handler is detached.
        self.out = self.io.getvalue()

    def __repr__(self):
        return "captured: {}\n".format(self.out)
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.

    Args:
        shape: Shape of the tensor to create, as a tuple or a list; `shape[0]` is the batch size.
        generator: A single generator, or a list with one generator per batch element.
        device: Device of the returned tensor (a `torch.device` or a device string). Defaults to CPU.
        dtype: Passed through to `torch.randn`.
        layout: Passed through to `torch.randn`. Defaults to `torch.strided`.

    Returns:
        A tensor of the requested shape located on `device`.

    Raises:
        ValueError: if a CUDA generator is passed while the requested device is of another type.
    """
    # Accept device strings too: the checks below read `device.type`, which a plain `str` does not have.
    if isinstance(device, str):
        device = torch.device(device)

    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a `torch.device` compared with the string "mps" is never equal, so
            # the previous `device != "mps"` check logged on every device, mps included.
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    if isinstance(generator, list):
        # One sample per generator; `tuple(...)` keeps this working when `shape` is given as a list.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
class CCProjection(ModelMixin, ConfigMixin):
    """Single linear layer mapping `in_channel` features to `out_channel` features."""

    @register_to_config
    def __init__(self, in_channel=772, out_channel=768):
        super().__init__()
        # The attribute name `projection` is part of the checkpoint layout
        # (it prefixes the layer's parameter names), so it must not change.
        self.projection = torch.nn.Linear(in_features=in_channel, out_features=out_channel)
        self.out_channel = out_channel
        self.in_channel = in_channel

    def forward(self, x):
        """Apply the linear projection to `x` and return the result."""
        projected = self.projection(x)
        return projected
eps): super().__init__() self.in_channel = in_channel self.layernorm = torch.nn.LayerNorm(in_channel, eps=eps) def forward(self, x): return self.layernorm(x) class Zero1to3StableDiffusionPipeline(DiffusionPipeline): r""" Pipeline for single view conditioned novel view generation using Zero1to3. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. image_encoder ([`CLIPVisionModelWithProjection`]): Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
def __init__(
    self,
    vae: AutoencoderKL,
    image_encoder: CN_encoder,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: AutoImageProcessor,
    # cc_projection: CCProjection,
    # CLIP_projection: CLIPProjection,
    requires_safety_checker: bool = True,
):
    """Assemble the pipeline from its sub-models.

    Outdated scheduler / UNet configurations are patched in place (with a
    deprecation notice), the components are registered on the pipeline, and
    the image preprocessing transform used before `image_encoder` is built.

    Args:
        vae: Autoencoder used to move between image space and latent space.
        image_encoder: Encoder that produces the image conditioning.
        unet: Conditional denoising U-Net.
        scheduler: Noise scheduler driving the denoising loop.
        safety_checker: Optional NSFW classifier; may be `None`.
        feature_extractor: Preprocessor feeding `safety_checker`; required
            whenever `safety_checker` is given.
        requires_safety_checker: Stored in the pipeline config; when `True`
            and no `safety_checker` is given, a warning is logged.

    Raises:
        ValueError: If `safety_checker` is given without `feature_extractor`.
    """
    super().__init__()

    # Legacy scheduler configs may carry steps_offset != 1; force it to 1.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["steps_offset"] = 1
        scheduler._internal_dict = FrozenDict(new_config)

    # Legacy scheduler configs may enable clip_sample; force it off.
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        deprecation_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(scheduler.config)
        new_config["clip_sample"] = False
        scheduler._internal_dict = FrozenDict(new_config)

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # This message needs the f-prefix; without it the placeholder
        # `{self.__class__}` was shown to the user verbatim.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # Old UNet configs (< 0.9.0) with sample_size < 64 are patched to 64.
    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
        version.parse(unet.config._diffusers_version).base_version
    ) < version.parse("0.9.0.dev0")
    is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
        deprecation_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
        new_config = dict(unet.config)
        new_config["sample_size"] = 64
        unet._internal_dict = FrozenDict(new_config)

    self.register_modules(
        vae=vae,
        image_encoder=image_encoder,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        # cc_projection=cc_projection,
        # CLIP_projection=CLIP_projection,
    )
    # Spatial down-sampling factor between image space and latent space.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    # self.model_mode = None

    # Tensor-only preprocessing applied before `image_encoder`: bicubic resize
    # to 224x224 followed by per-channel mean/std normalisation.
    self.ConvNextV2_preprocess = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
        # transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
def enable_sequential_cpu_offload(self, gpu_id=0):
    r"""
    Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet,
    image_encoder, vae and (if present) safety checker are handed to accelerate's `cpu_offload`, with
    `cuda:{gpu_id}` as the execution device. Memory savings are higher than with `enable_model_cpu_offload`,
    but performance is lower.

    Args:
        gpu_id (`int`, defaults to 0): Index of the CUDA device used for execution.

    Raises:
        ImportError: If accelerate >= 0.14.0 is not available.
    """
    if not (is_accelerate_available() and is_accelerate_version(">=", "0.14.0")):
        raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
    from accelerate import cpu_offload

    device = torch.device(f"cuda:{gpu_id}")

    if self.device.type != "cpu":
        self.to("cpu", silence_dtype_warnings=True)
        torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

    # This pipeline registers an `image_encoder`, not a `text_encoder`;
    # touching `self.text_encoder` here raised AttributeError.
    for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae]:
        cpu_offload(cpu_offloaded_model, device)

    if self.safety_checker is not None:
        cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
@property
def _execution_device(self):
    r"""
    Device on which the pipeline's models are executed.

    When the unet carries an accelerate hook (`_hf_hook`), the first submodule hook that reports a non-`None`
    `execution_device` decides the result; otherwise the pipeline's own `device` is returned.
    """
    if hasattr(self.unet, "_hf_hook"):
        for submodule in self.unet.modules():
            hook = getattr(submodule, "_hf_hook", None)
            target = getattr(hook, "execution_device", None)
            if target is not None:
                return torch.device(target)
    # No hook, or no hook that names an execution device.
    return self.device
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. 
""" if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. 
def CLIP_preprocess(self, x):
    """Resize and normalise an image batch the way OpenAI's CLIP preprocessing does.

    `x` is expected to hold values in [-1, 1]. It is resized to 224x224
    (bicubic, computed in float32 and cast back to the input dtype), mapped to
    [0, 1] and finally normalised with the CLIP mean / std constants.

    Raises:
        ValueError: If `x` is a tensor with values outside [-1, 1].
    """
    input_dtype = x.dtype
    # following openai's implementation
    # TODO HF OpenAI CLIP preprocessing issue https://github.com/huggingface/transformers/issues/22505#issuecomment-1650170741
    # follow openai preprocessing to keep exact same, input tensor [-1, 1], otherwise the preprocessing will be different, https://github.com/huggingface/transformers/pull/22608
    if isinstance(x, torch.Tensor) and (x.min() < -1.0 or x.max() > 1.0):
        raise ValueError("Expected input tensor to have values in the range [-1, 1]")

    clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
    clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])

    resized = kornia.geometry.resize(
        x.to(torch.float32), (224, 224), interpolation='bicubic', align_corners=True, antialias=False
    )
    unit_range = (resized.to(dtype=input_dtype) + 1.) / 2.
    # renormalize according to clip
    return kornia.enhance.normalize(unit_range, clip_mean, clip_std)
np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 image = image.to(device=device, dtype=dtype) # image = self.CLIP_preprocess(image) # todo # if not isinstance(image, torch.Tensor): # # 0-255 # print("Warning: image is processed by hf's preprocess, which is different from openai original's.") # image = self.feature_extractor(images=image, return_tensors="pt").pixel_values # image_embeddings = self.image_encoder(image).image_embeds.to(dtype=dtype) # image_embeddings = image_embeddings.unsqueeze(1) # clip_embeddings = self.image_encoder(image).last_hidden_state.to(dtype=dtype)[:, 1:, :] # bt,257,1024 # image_embeddings = self.CLIP_projection(clip_embeddings).to(dtype=dtype) # bt,256,768 # todo # [-1, 1] -> [0, 1] image = (image + 1.) / 2. image = self.ConvNextV2_preprocess(image) image_embeddings = self.image_encoder(image)#.last_hidden_state # bt, 768, 12, 12 # image_embeddings = einops.rearrange(image_embeddings, 'b c h w -> b (h w) c') # image_embeddings = self.CN_layernorm(image_embeddings) # todo # duplicate image embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = image_embeddings.shape image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # import pdb; pdb.set_trace() # todo debug clip_embeddings bf16, CLIP_projection.layer_norm.weight bf16, but get float32, and after visual_projection, get fp16 rather than bf16 if do_classifier_free_guidance: negative_prompt_embeds = torch.zeros_like(image_embeddings) # For classifier free guidance, we need to do two forward passes. 
# Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings # def _encode_pose(self, pose, device, num_images_per_prompt, do_classifier_free_guidance): # dtype = next(self.cc_projection.parameters()).dtype # if isinstance(pose, torch.Tensor): # pose_embeddings = pose.unsqueeze(1).to(device=device, dtype=dtype) # else: # if isinstance(pose[0], list): # pose = torch.Tensor(pose) # else: # pose = torch.Tensor([pose]) # x, y, z = pose[:,0].unsqueeze(1), pose[:,1].unsqueeze(1), pose[:,2].unsqueeze(1) # pose_embeddings = torch.cat([torch.deg2rad(x), # torch.sin(torch.deg2rad(y)), # torch.cos(torch.deg2rad(y)), # z], dim=-1).unsqueeze(1).to(device=device, dtype=dtype) # B, 1, 4 # # duplicate pose embeddings for each generation per prompt, using mps friendly method # bs_embed, seq_len, _ = pose_embeddings.shape # pose_embeddings = pose_embeddings.repeat(1, num_images_per_prompt, 1) # pose_embeddings = pose_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) # if do_classifier_free_guidance: # negative_prompt_embeds = torch.zeros_like(pose_embeddings) # # # For classifier free guidance, we need to do two forward passes. 
def decode_latents(self, latents):
    """Decode latents with the VAE and return images as a float32 NumPy array.

    The latents are divided by the VAE scaling factor, decoded, mapped from
    [-1, 1] to [0, 1] (with clamping) and rearranged to channels-last order.
    """
    inv_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inv_scale * latents).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return pixels.cpu().permute(0, 2, 3, 1).float().numpy()
def check_inputs(self, image, height, width, callback_steps):
    """Validate the user-facing call arguments.

    Raises:
        ValueError: If `image` is not a tensor, a PIL image or a list; if
            `height` or `width` is not divisible by 8; or if `callback_steps`
            is not a positive integer (`None` is rejected as well).
    """
    is_supported_image = (
        isinstance(image, torch.Tensor) or isinstance(image, PIL.Image.Image) or isinstance(image, list)
    )
    if not is_supported_image:
        raise ValueError(
            "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
            f" {type(image)}"
        )

    if any(side % 8 != 0 for side in (height, width)):
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    steps_are_valid = isinstance(callback_steps, int) and callback_steps > 0
    if not steps_are_valid:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
def prepare_img_latents(self, image, batch_size, dtype, device, generator=None, do_classifier_free_guidance=False, t_in=None):
    """Encode conditioning image(s) into VAE latents for the denoising loop.

    Args:
        image: A tensor with values in [-1, 1] of shape (3, H, W) or
            (B, 3, H, W), a PIL image, or a list of PIL images / HWC arrays
            with 0-255 values.
        batch_size: Effective batch size. When it exceeds the number of
            encoded images, each latent is repeated
            `batch_size // num_images` times (consecutively per image).
        dtype: Data type of the returned latents.
        device: Device of the returned latents.
        generator: Only validated for length. The latents are the mode of the
            posterior, which is deterministic, so no random numbers are drawn.
        do_classifier_free_guidance: When true, an all-zero copy of the
            latents is prepended as the unconditional branch.
        t_in: Unused; kept for call-site compatibility.

    Returns:
        The image latents on `device` with type `dtype`.

    Raises:
        ValueError: If `image` has an unsupported type, a tensor image leaves
            the [-1, 1] range, or the generator list length differs from
            `batch_size`.
    """
    if isinstance(image, torch.Tensor):
        # Batch single image
        if image.ndim == 3:
            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
            image = image.unsqueeze(0)

        assert image.ndim == 4, "Image must have 4 dimensions"

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError("Image should be in [-1, 1] range")
    elif isinstance(image, (PIL.Image.Image, list)):
        # Convert PIL images / HWC arrays to a (B, 3, H, W) tensor in [-1, 1].
        if isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            image = np.concatenate([np.array(i.convert("RGB"))[None, :] for i in image], axis=0)
        elif isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    else:
        raise ValueError(
            f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # `mode()` takes no arguments and is deterministic, so the whole batch is
    # encoded at once regardless of `generator`. The previous per-generator
    # loop called `mode(generator[i])`, which raised TypeError.
    init_latents = self.vae.encode(image).latent_dist.mode()

    # Author's note: the latents are deliberately NOT multiplied by
    # `vae.config.scaling_factor`, following the original zero123 inference
    # script (gradio_new.py), where model.encode_first_stage() is unscaled.
    if batch_size > init_latents.shape[0]:
        num_images_per_prompt = batch_size // init_latents.shape[0]
        # duplicate image latents for each generation per prompt, using mps friendly method
        bs_embed, emb_c, emb_h, emb_w = init_latents.shape
        init_latents = init_latents.unsqueeze(1)
        init_latents = init_latents.repeat(1, num_images_per_prompt, 1, 1, 1)
        init_latents = init_latents.view(bs_embed * num_images_per_prompt, emb_c, emb_h, emb_w)

    # follow zero123: the unconditional branch uses all-zero image latents
    if do_classifier_free_guidance:
        init_latents = torch.cat([torch.zeros_like(init_latents), init_latents])

    return init_latents.to(device=device, dtype=dtype)
Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: input_imgs (`PIL` or `List[PIL]`, *optional*): The single input image for each 3D object prompt_imgs (`PIL` or `List[PIL]`, *optional*): Same as input_imgs, but will be used later as an image prompt condition, encoded by CLIP feature height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. 
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. 
The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor assert T_out == poses[0][0].shape[1] # 1. Check inputs. Raise error if not correct # input_image = hint_imgs self.check_inputs(input_imgs, height, width, callback_steps) # # todo hard code # self.proj3d = Proj3DVolume(volume_dims=[], feature_dims=[], T_in=1, T_out=1, bound=1.0) # todo T_in=1 # 2. Define call parameters if isinstance(input_imgs, PIL.Image.Image): batch_size = 1 elif isinstance(input_imgs, list): batch_size = len(input_imgs) else: batch_size = input_imgs.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input image with pose as prompt # prompt_embeds = self._encode_image_with_pose(prompt_imgs, poses, device, num_images_per_prompt, do_classifier_free_guidance, t_in) prompt_embeds = self._encode_image(prompt_imgs, device, num_images_per_prompt, do_classifier_free_guidance) prompt_embeds = einops.rearrange(prompt_embeds, '(b t) l c -> b (t l) c', t=T_in) if do_classifier_free_guidance: [pose_out, pose_out_inv], [pose_in, pose_in_inv] = poses pose_in = torch.cat([pose_in] * 2) pose_out = torch.cat([pose_out] * 2) pose_in_inv = torch.cat([pose_in_inv] * 2) pose_out_inv = torch.cat([pose_out_inv] * 2) poses = [[pose_out, pose_out_inv], [pose_in, pose_in_inv]] # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables latents = self.prepare_latents( batch_size // T_in * T_out * num_images_per_prompt, # todo use t_out 4, height, width, prompt_embeds.dtype, device, generator, latents, )# todo same init noise along T? # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # latent_model_input = torch.cat([latent_model_input, img_latents], dim=1) latent_model_input = torch.cat([latent_model_input], dim=1) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, pose=poses).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype) latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 8. Post-processing has_nsfw_concept = None if output_type == "latent": image = latents elif output_type == "pil": # 8. Post-processing image = self.decode_latents(latents) # 10. Convert to PIL image = self.numpy_to_pil(image) else: # 8. 
Post-processing image = self.decode_latents(latents) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) ================================================ FILE: 6DoF/train_eschernet.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def image_grid(imgs, rows, cols):
    """Tile equally sized PIL images into a single `rows` x `cols` RGB image.

    Images are placed in row-major order; the tile size is taken from the
    first image. `len(imgs)` must equal `rows * cols`.
    """
    assert len(imgs) == rows * cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(imgs):
        # divmod yields (row, column) of this tile in the row-major layout.
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas
".format(split)) scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") pipeline = Zero1to3StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=accelerator.unwrap_model(vae).eval(), image_encoder=accelerator.unwrap_model(image_encoder).eval(), feature_extractor=feature_extractor, unet=accelerator.unwrap_model(unet).eval(), scheduler=scheduler, safety_checker=None, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) if args.enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) image_logs = [] val_lpips = 0 val_ssim = 0 val_psnr = 0 val_loss = 0 val_num = 0 T_out = args.T_out # fix to be 1? for T_in_val in [1, args.T_in_val//2, args.T_in_val]: # eval different number of given views for valid_step, batch in tqdm(enumerate(validation_dataloader)): if args.num_validation_batches is not None and valid_step >= args.num_validation_batches: break T_in = T_in_val gt_image = batch["image_target"].to(dtype=weight_dtype) input_image = batch["image_input"].to(dtype=weight_dtype)[:, :T_in] pose_in = batch["pose_in"].to(dtype=weight_dtype)[:, :T_in] # BxTx4 pose_out = batch["pose_out"].to(dtype=weight_dtype) # BxTx4 pose_in_inv = batch["pose_in_inv"].to(dtype=weight_dtype)[:, :T_in] # BxTx4 pose_out_inv = batch["pose_out_inv"].to(dtype=weight_dtype) # BxTx4 gt_image = einops.rearrange(gt_image, 'b t c h w -> (b t) c h w', t=T_out) input_image = einops.rearrange(input_image, 'b t c h w -> (b t) c h w', t=T_in) # T_in images = [] h, w = input_image.shape[2:] for _ in range(args.num_validation_images): with torch.autocast("cuda"): image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]], height=h, width=w, T_in=T_in, 
T_out=pose_out.shape[1], guidance_scale=args.guidance_scale, num_inference_steps=50, generator=generator, output_type="numpy").images pred_image = torch.from_numpy(image * 2. - 1.).permute(0, 3, 1, 2) images.append(pred_image) pred_np = (image * 255).astype(np.uint8) # [0,1] gt_np = (gt_image / 2 + 0.5).clamp(0, 1) gt_np = (gt_np.cpu().permute(0, 2, 3, 1).float().numpy()*255).astype(np.uint8) # for 1 image # pixel loss loss = F.mse_loss(pred_image[0], gt_image[0].cpu()).item() # LPIPS lpips = LPIPS(pred_image[0], gt_image[0].cpu()).item() # [-1, 1] torch tensor # SSIM ssim = calculate_ssim(pred_np[0], gt_np[0], channel_axis=2) # PSNR psnr = cv2.PSNR(gt_np[0], pred_np[0]) val_loss += loss val_lpips += lpips val_ssim += ssim val_psnr += psnr val_num += 1 image_logs.append( {"gt_image": gt_image, "pred_images": images, "input_image": input_image} ) pixel_loss = val_loss / val_num pixel_lpips= val_lpips / val_num pixel_ssim = val_ssim / val_num pixel_psnr = val_psnr / val_num for tracker in accelerator.trackers: if tracker.name == "wandb": # need to use table, wandb doesn't allow more than 108 images assert args.num_validation_images == 2 table = wandb.Table(columns=["Input", "GT", "Pred1", "Pred2"]) for log_id, log in enumerate(image_logs): formatted_images = [[], [], []] # [[input], [gt], [pred]] pred_images = log["pred_images"] # pred input_image = log["input_image"] # input gt_image = log["gt_image"] # GT formatted_images[0].append(wandb.Image(input_image, caption="{}_input".format(log_id))) formatted_images[1].append(wandb.Image(gt_image, caption="{}_gt".format(log_id))) for sample_id, pred_image in enumerate(pred_images): # n_samples pred_image = wandb.Image(pred_image, caption="{}_pred_{}".format(log_id, sample_id)) formatted_images[2].append(pred_image) table.add_data(*formatted_images[0], *formatted_images[1], *formatted_images[2]) tracker.log({split: table, # formatted_images "{}_T{}_pixel_loss".format(split, T_in_val): pixel_loss, 
def parse_args(input_args=None):
    """Build the command-line parser for the training script and parse arguments.

    Args:
        input_args: Optional list of argument strings. When `None`, arguments
            are read from `sys.argv` by argparse.

    Returns:
        The populated `argparse.Namespace`.

    Raises:
        ValueError: If neither or both of `--dataset_name` / `--train_data_dir`
            are given, or if `--resolution` is not divisible by 8.
    """

    def _str_to_bool(value):
        """Convert a command-line token such as "true"/"False"/"0" to a bool."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Simple example of a Zero123 training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="lambdalabs/sd-image-variations-diffusers",
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="eschernet-6dof",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=256,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--T_in", type=int, default=1, help="Number of input views")
    parser.add_argument(
        "--T_in_val", type=int, default=10, help="Maximum number of input views used during validation"
    )
    parser.add_argument("--T_out", type=int, default=1, help="Number of output views")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=100000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.0,
        help="unconditional guidance scale, if guidance_scale>1.0, do_classifier_free_guidance",
    )
    parser.add_argument(
        "--conditioning_dropout_prob",
        type=float,
        default=0.05,
        help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=2000,
        help=(
            "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
            "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
            "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
            "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
            "instructions."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=20,
        help=("Max number of checkpoints to store."),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=0.5, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",  # log_image currently only for wandb
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    # Boolean options: without an explicit converter argparse stores the raw string, so
    # "--flag False" used to yield the truthy string "False". `nargs="?"` additionally
    # allows the bare flag form, which means True.
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether or not to use xformers.",
    )
    parser.add_argument(
        "--set_grads_to_none",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=True,
        help=(
            "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=2,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=2000,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--num_validation_batches",
        type=int,
        default=20,
        help=(
            "Number of batches to use for validation. If `None`, use all batches."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_zero123_hf",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Exactly one data source must be selected.
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")

    if args.dataset_name is not None and args.train_data_dir is not None:
        raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images."
        )

    return args
image = ConvNextV2_preprocess(image) image_embeddings = image_encoder(image) # bt, 768, 12, 12 if do_classifier_free_guidance: negative_prompt_embeds = torch.zeros_like(image_embeddings) image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) return image_embeddings #.detach() # !we need keep image encoder gradient def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token, private=True ).repo_id # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision) image_encoder = CN_encoder.from_pretrained("facebook/convnextv2-tiny-22k-224") feature_extractor = None vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision) T_in = args.T_in T_in_val = args.T_in_val T_out = args.T_out vae.eval() vae.requires_grad_(False) image_encoder.train() image_encoder.requires_grad_(True) unet.requires_grad_(True) unet.train() # Create EMA for the unet. if args.use_ema: ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() vae.enable_slicing() else: raise ValueError("xformers is not available. 
Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(unet).dtype != torch.float32: raise ValueError( f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW optimizer = optimizer_class( [{"params": unet.parameters(), "lr": args.learning_rate}, {"params": image_encoder.parameters(), "lr": args.learning_rate}], betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon ) # print model info, learnable parameters, non-learnable parameters, total parameters, model size, all in billion def print_model_info(model): print("="*20) # print model class name print("model name: ", type(model).__name__) print("learnable parameters(M): ", sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6) print("non-learnable parameters(M): ", sum(p.numel() for p in model.parameters() if not p.requires_grad) / 1e6) print("total parameters(M): ", sum(p.numel() for p in model.parameters()) / 1e6) print("model size(MB): ", sum(p.numel() * p.element_size() for p in model.parameters()) / 1024 / 1024) print_model_info(unet) print_model_info(vae) print_model_info(image_encoder) # Init Dataset image_transforms = torchvision.transforms.Compose( [ torchvision.transforms.Resize((args.resolution, args.resolution)), # 256, 256 transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ] ) train_dataset = ObjaverseData(root_dir=args.train_data_dir, image_transforms=image_transforms, validation=False, T_in=T_in, T_out=T_out) train_log_dataset = ObjaverseData(root_dir=args.train_data_dir, image_transforms=image_transforms, validation=False, T_in=T_in_val, T_out=T_out, fix_sample=True) validation_dataset = ObjaverseData(root_dir=args.train_data_dir, image_transforms=image_transforms, validation=True, T_in=T_in_val, T_out=T_out, fix_sample=True) # for training train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) # for validation set logs validation_dataloader = torch.utils.data.DataLoader( validation_dataset, shuffle=False, 
batch_size=1, num_workers=1, ) # for training set logs train_log_dataloader = torch.utils.data.DataLoader( train_log_dataset, shuffle=False, batch_size=1, num_workers=1, ) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): """Warmup the learning rate""" lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) for param_group in optimizer.param_groups: param_group['lr'] = lr def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): """Decay the learning rate""" lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr for param_group in optimizer.param_groups: param_group['lr'] = lr # Prepare everything with our `accelerator`. unet, image_encoder, optimizer, train_dataloader, validation_dataloader, train_log_dataloader = accelerator.prepare( unet, image_encoder, optimizer, train_dataloader, validation_dataloader, train_log_dataloader ) if args.use_ema: ema_unet.to(accelerator.device) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move vae, image_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) run_name = args.output_dir.split("logs_")[1] accelerator.init_trackers(args.tracker_project_name, config=tracker_config, init_kwargs={"wandb":{"name":run_name}}) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps do_classifier_free_guidance = args.guidance_scale > 1.0 logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") logger.info(f" do_classifier_free_guidance = {do_classifier_free_guidance}") logger.info(f" conditioning_dropout_prob = {args.conditioning_dropout_prob}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) for epoch in range(first_epoch, args.num_train_epochs): loss_epoch = 0.0 num_train_elems = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet, image_encoder): gt_image = batch["image_target"].to(dtype=weight_dtype) # BxTx3xHxW gt_image = einops.rearrange(gt_image, 'b t c h w -> (b t) c h w', t=T_out) input_image = batch["image_input"].to(dtype=weight_dtype) # Bx3xHxW input_image = einops.rearrange(input_image, 'b t c h w -> (b t) c h w', t=T_in) pose_in = batch["pose_in"].to(dtype=weight_dtype) # BxTx4 pose_out = batch["pose_out"].to(dtype=weight_dtype) # BxTx4 pose_in_inv = batch["pose_in_inv"].to(dtype=weight_dtype) # BxTx4 pose_out_inv = batch["pose_out_inv"].to(dtype=weight_dtype) # BxTx4 gt_latents = vae.encode(gt_image).latent_dist.sample().detach() gt_latents = gt_latents * vae.config.scaling_factor # follow zero123, only target image latent is scaled # Sample noise that we'll add to the latents bsz = gt_latents.shape[0] // T_out noise = torch.randn_like(gt_latents) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device) timesteps = timesteps.long() timesteps = einops.repeat(timesteps, 'b -> (b t)', t=T_out) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(gt_latents.to(dtype=torch.float32), noise.to(dtype=torch.float32), timesteps).to(dtype=gt_latents.dtype) if do_classifier_free_guidance: #support classifier-free guidance, randomly drop out 5% # Conditioning dropout to support classifier-free guidance during inference. For more details # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800. random_p = torch.rand(bsz, device=gt_latents.device) # Sample masks for the edit prompts. 
prompt_mask = random_p < 2 * args.conditioning_dropout_prob prompt_mask = prompt_mask.reshape(bsz, 1, 1, 1) img_prompt_embeds = _encode_image(feature_extractor, image_encoder, input_image, gt_latents.device, gt_latents.dtype, False) # Final text conditioning. img_prompt_embeds = einops.rearrange(img_prompt_embeds, '(b t) l c -> b t l c', t=T_in) null_conditioning = torch.zeros_like(img_prompt_embeds).detach() img_prompt_embeds = torch.where(prompt_mask, null_conditioning, img_prompt_embeds) img_prompt_embeds = einops.rearrange(img_prompt_embeds, 'b t l c -> (b t) l c', t=T_in) prompt_embeds = torch.cat([img_prompt_embeds], dim=-1) else: # Get the image_with_pose embedding for conditioning prompt_embeds = _encode_image(feature_extractor, image_encoder, input_image, gt_latents.device, gt_latents.dtype, False) prompt_embeds = einops.rearrange(prompt_embeds, '(b t) l c -> b (t l) c', t=T_in) # noisy_latents (b T_out) latent_model_input = torch.cat([noisy_latents], dim=1) # Predict the noise residual model_pred = unet( latent_model_input, timesteps, encoder_hidden_states=prompt_embeds, # (bxT_in) l 768 pose=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]], # (bxT_in) 4, pose_out - self-attn, pose_in - cross-attn ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(gt_latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = (loss.mean([1, 2, 3])).mean() accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = itertools.chain(unet.parameters(), image_encoder.parameters()) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() # cosine if global_step <= args.lr_warmup_steps: warmup_lr_schedule(optimizer, 
global_step, args.lr_warmup_steps, 1e-5, args.learning_rate) else: cosine_lr_schedule(optimizer, global_step, args.max_train_steps, args.learning_rate, 1e-5) optimizer.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: ema_unet.step(unet.parameters()) progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") # save pipeline # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: pipelines = os.listdir(args.output_dir) pipelines = [d for d in pipelines if d.startswith("pipeline")] pipelines = sorted(pipelines, key=lambda x: int(x.split("-")[1])) # before we save the new pipeline, we need to have at _most_ `checkpoints_total_limit - 1` 
pipeline if len(pipelines) >= args.checkpoints_total_limit: num_to_remove = len(pipelines) - args.checkpoints_total_limit + 1 removing_pipelines = pipelines[0:num_to_remove] logger.info( f"{len(pipelines)} pipelines already exist, removing {len(removing_pipelines)} pipelines" ) logger.info(f"removing pipelines: {', '.join(removing_pipelines)}") for removing_pipeline in removing_pipelines: removing_pipeline = os.path.join(args.output_dir, removing_pipeline) shutil.rmtree(removing_pipeline) if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) pipeline = Zero1to3StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=accelerator.unwrap_model(vae), image_encoder=accelerator.unwrap_model(image_encoder), feature_extractor=feature_extractor, unet=accelerator.unwrap_model(unet), scheduler=noise_scheduler, safety_checker=None, torch_dtype=torch.float32, ) pipeline_save_path = os.path.join(args.output_dir, f"pipeline-{global_step}") pipeline.save_pretrained(pipeline_save_path) # del pipeline if args.push_to_hub: print("Pushing to the hub ", repo_id) upload_folder( repo_id=repo_id, folder_path=pipeline_save_path, commit_message=global_step, ignore_patterns=["step_*", "epoch_*"], run_as_future=True, ) if args.use_ema: # Switch back to the original UNet parameters. ema_unet.restore(unet.parameters()) if validation_dataloader is not None and global_step % args.validation_steps == 0: if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) image_logs = log_validation( validation_dataloader, vae, image_encoder, feature_extractor, unet, args, accelerator, weight_dtype, 'val', ) if args.use_ema: # Switch back to the original UNet parameters. 
if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run training.
    # The sharing-strategy override below is commented out in the original and is kept
    # here for reference only.
    # torch.multiprocessing.set_sharing_strategy("file_system")
    # `args` is bound at module level on purpose-preserving grounds: other code in this
    # file may read it by name — TODO confirm before inlining the call.
    args = parse_args()
    main(args)
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    Result container returned by [`UNet2DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Output of the model's last layer, i.e. the hidden states conditioned on the
            `encoder_hidden_states` input.
    """

    # Defaults to `None` so the container can be created before a tensor is attached.
    sample: torch.FloatTensor = None
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers are skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim`. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim: (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings. num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. 
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads=64,
):
    """Validate the configuration and build every sub-module of the UNet.

    Construction order: input convolution, timestep projection/embedding, optional
    projection of `encoder_hidden_states`, optional class and additional embeddings,
    down blocks, mid block, up blocks, and finally the output norm/activation/conv.
    The individual parameters are described in the class docstring.

    Raises:
        ValueError: If `num_attention_heads` is passed, if the per-block settings do not
            have one entry per down block, if `time_embedding_type`,
            `encoder_hid_dim_type`, `addition_embed_type` or `mid_block_type` is not a
            supported value, if `encoder_hid_dim_type` is set without `encoder_hid_dim`,
            if a Fourier time embedding dimension is odd, or if a projection class
            embedding is requested without `projection_class_embeddings_input_dim`.
    """
    super().__init__()

    self.sample_size = sample_size

    if num_attention_heads is not None:
        raise ValueError(
            "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
        )

    # If `num_attention_heads` is not defined (which is the case for most models)
    # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
    # The reason for this behavior is to correct for incorrectly named variables that were introduced
    # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
    # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
    # which is why we correct for the naming here.
    num_attention_heads = num_attention_heads or attention_head_dim

    # Check inputs
    if len(down_block_types) != len(up_block_types):
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
        )

    # The parameter is annotated as a tuple, so validate tuples as well as the lists
    # that come back from a JSON config; otherwise a wrong-length tuple slips through.
    if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # input
    conv_in_padding = (conv_in_kernel - 1) // 2
    self.conv_in = nn.Conv2d(
        in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
    )

    # time
    if time_embedding_type == "fourier":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
        if time_embed_dim % 2 != 0:
            raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
        self.time_proj = GaussianFourierProjection(
            time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
        )
        timestep_input_dim = time_embed_dim
    elif time_embedding_type == "positional":
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]
    else:
        raise ValueError(
            f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
        )

    self.time_embedding = TimestepEmbedding(
        timestep_input_dim,
        time_embed_dim,
        act_fn=act_fn,
        post_act_fn=timestep_post_act,
        cond_proj_dim=time_cond_proj_dim,
    )

    # A bare `encoder_hid_dim` implies a text projection; record the implied type in the config.
    if encoder_hid_dim_type is None and encoder_hid_dim is not None:
        encoder_hid_dim_type = "text_proj"
        self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
        logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

    if encoder_hid_dim is None and encoder_hid_dim_type is not None:
        raise ValueError(
            f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
        )

    if encoder_hid_dim_type == "text_proj":
        self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
    elif encoder_hid_dim_type == "text_image_proj":
        # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
        self.encoder_hid_proj = TextImageProjection(
            text_embed_dim=encoder_hid_dim,
            image_embed_dim=cross_attention_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type == "image_proj":
        # Kandinsky 2.2
        self.encoder_hid_proj = ImageProjection(
            image_embed_dim=encoder_hid_dim,
            cross_attention_dim=cross_attention_dim,
        )
    elif encoder_hid_dim_type is not None:
        # The message lists every value accepted by the branches above.
        raise ValueError(
            f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
        )
    else:
        self.encoder_hid_proj = None

    # class embedding
    if class_embed_type is None and num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
    elif class_embed_type == "timestep":
        self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
    elif class_embed_type == "identity":
        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
    elif class_embed_type == "projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
            )
        # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
        # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
        # 2. it projects from an arbitrary input dimension.
        #
        # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
        # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
        # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
        self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif class_embed_type == "simple_projection":
        if projection_class_embeddings_input_dim is None:
            raise ValueError(
                "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
            )
        self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
    else:
        self.class_embedding = None

    if addition_embed_type == "text":
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim

        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
        # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
        # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type is not None:
        # The message lists every value accepted by the branches above.
        raise ValueError(
            f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
        )

    if time_embedding_act_fn is None:
        self.time_embed_act = None
    else:
        self.time_embed_act = get_activation(time_embedding_act_fn)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar settings so that every option below has one entry per down block.
    if isinstance(only_cross_attention, bool):
        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = only_cross_attention

        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if mid_block_only_cross_attention is None:
        mid_block_only_cross_attention = False

    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * len(down_block_types)

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

    if class_embeddings_concat:
        # The time embeddings are concatenated with the class embeddings. The dimension of the
        # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
        # regular time embeddings
        blocks_time_embed_dim = time_embed_dim * 2
    else:
        blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            downsample_padding=downsample_padding,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.down_blocks.append(down_block)

    # mid
    if mid_block_type == "UNetMidBlock2DCrossAttn":
        self.mid_block = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block[-1],
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )
    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        self.mid_block = UNetMidBlock2DSimpleCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            cross_attention_dim=cross_attention_dim[-1],
            attention_head_dim=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
        )
    elif mid_block_type is None:
        self.mid_block = None
    else:
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up: the decoder walks the per-block settings in reverse order of the encoder
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    only_cross_attention = list(reversed(only_cross_attention))

    output_channel = reversed_block_out_channels[0]
    # Output channel count of each up block, in construction order.
    self.up_block_out_channels = []
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        if not is_final_block:
            add_upsample = True
            self.num_upsamplers += 1
        else:
            add_upsample = False

        # NOTE(review): `attention_head_dim` is indexed in down-block order below while the
        # other per-block settings are reversed — confirm this is intended before changing it.
        up_block = get_up_block(
            up_block_type,
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention[i],
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
        )
        self.up_blocks.append(up_block)
        prev_output_channel = output_channel
        self.up_block_out_channels.append(output_channel)

    # out
    if norm_num_groups is not None:
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )

        self.conv_act = get_activation(act_fn)

    else:
        self.conv_norm_out = None
        self.conv_act = None

    conv_out_padding = (conv_out_kernel - 1) // 2
    self.conv_out = nn.Conv2d(
        block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
    )

    # Kept on the instance so the channel layout can be read back after construction.
    self.block_out_channels = block_out_channels
    self.reversed_block_out_channels = reversed_block_out_channels
""" # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "set_processor"): processors[f"{name}.processor"] = module.processor for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. 
""" self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, pose = None, # (b T_in) 4 class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 
encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs: (`dict`, *optional*): A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. 
class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = 
added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kadinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 3. 
down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, posemb=pose, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) # default down_block_res_samples += res_samples if down_block_additional_residuals is not None: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, posemb=pose, ) if mid_block_additional_residual is not None: sample = sample + mid_block_additional_residual # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, posemb=pose, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) ================================================ FILE: CaPE.py ================================================ import numpy as np import einops import torch from scipy.spatial.transform import Rotation as R ########################## 6DoF CaPE #################################### class CaPE_6DoF: def cape_embed(self, f, P): """ Apply CaPE on feature. :param f: feature vector of shape [..., d] :param P: 4x4 transformation matrix :return: rotated feature f by pose P: f@P """ f = einops.rearrange(f, '... (d k) -> ... d k', k=4) return einops.rearrange(f@P, '... d k -> ... (d k)', k=4) def attn_with_CaPE(self, f1, f2, p1, p2): """ Do attention dot production with CaPE pose encoding. 
# query = cape_embed(query, p_out_inv) # query f_q @ (p_out)^(-T) # key = cape_embed(key, p_in) # key f_k @ p_in :param f1: b (t1 l) d :param f2: b (t2 l) d :param p1: [b, t, 4, 4] :param p2: [b, t, 4, 4] :return: attention score: q@k.T """ l = f1.shape[1] // p1.shape[1] assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1] p1_invT = einops.repeat(torch.inverse(p1).permute(0, 1, 3, 2), 'b t m n -> b (t l) m n', l=l) # f1 [b, l*t1, d] query = self.cape_embed(f1, p1_invT) # [b, l*t1, d] query: f1 @ (p1)^(-T), transpose the last two dim p2_copy = einops.repeat(p2, 'b t m n -> b (t l) m n', l=l) # f2 [b, l*t2, d] key = self.cape_embed(f2, p2_copy) # [b, l*t2, d] key: f2 @ p2 att = query @ key.permute(0, 2, 1) # [b, l*t1, l*t2] attention: query@key^T return att ################### 6DoF Verification ################################### def euler_to_matrix(alpha, beta, gamma, x, y, z): # radian r = R.from_euler('xyz', [alpha, beta, gamma], degrees=True) t = np.array([[x], [y], [z]]) rot_matrix = r.as_matrix() rot_matrix = np.concatenate([rot_matrix, t], axis=-1) rot_matrix = np.concatenate([rot_matrix, [[0, 0, 0, 1]]], axis=0) return rot_matrix def random_6dof_pose(B, T): pose_euler = torch.rand([B, T, 6]).numpy() # euler pose_matrix = [] for b in range(B): p = [] for t in range(T): p.append(torch.from_numpy(euler_to_matrix(*pose_euler[b, t]))) pose_matrix.append(torch.stack(p)) pose_matrix = torch.stack(pose_matrix) return pose_matrix.float() bs = 6 # batch size t1 = 3 # num of target views in each batch, can be arbitrary number t2 = 5 # num of reference views in each batch, can be arbitrary number l = 10 # len of token d = 16 # dim of token feature, need to mod 4 in this case assert d % 4 == 0 # random init query and key f1 = torch.rand(bs, t1, l, d) # query f2 = torch.rand(bs, t2, l, d) # key f1 = einops.rearrange(f1, 'b t l d -> b (t l) d') f2 = einops.rearrange(f2, 'b t l d -> b (t l) d') # random init pose p1, p2, delta_p, [bs, t, 4, 4] p1 = 
random_6dof_pose(bs, t1) # [bs, t1, 4, 4] p2 = random_6dof_pose(bs, t2) # [bs, t2, 4, 4] p_delta = random_6dof_pose(bs, 1) # [bs, 1, 4, 4] # delta p is identical to p1 and p2 in each batch p1_delta = einops.repeat(p_delta, 'b 1 m n -> b (1 t) m n', t=t1//1) p2_delta = einops.repeat(p_delta, 'b 1 m n -> b (1 t) m n', t=t2//1) # run attention with CaPE 6DoF cape_6dof = CaPE_6DoF() # att att = cape_6dof.attn_with_CaPE(f1, f2, p1, p2) # att_delta att_delta = cape_6dof.attn_with_CaPE(f1, f2, p1@p1_delta, p2@p2_delta) # condition: att score should be the same i.e. non effect from any delta_p assert torch.allclose(att, att_delta, 1e-3) print("6DoF CaPE Verified") ########################## 4DoF CaPE #################################### class CaPE_4DoF: def rotate_every_two(self, x): x = einops.rearrange(x, '... (d j) -> ... d j', j=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return einops.rearrange(x, '... d j -> ... (d j)') def cape(self, x, p): d, l, n = x.shape[-1], p.shape[-2], p.shape[-1] assert d % (2 * n) == 0 m = einops.repeat(p, 'b l n -> b l (n k)', k=d // n) return m def cape_embed(self, qq, kk, p1, p2): """ Embed camera position encoding into attention map :param qq: query feature map [b, l_q, feature_dim] :param kk: key feature map [b, l_k, feature_dim] :param p1: query pose [b, l_q, pose_dim] :param p2: key pose [b, l_k, pose_dim] :return: cape embedded attention map [b, l_q, l_k] """ assert p1.shape[-1] == p2.shape[-1] assert qq.shape[-1] == kk.shape[-1] assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0] assert p1.shape[1] == qq.shape[1] assert p2.shape[1] == kk.shape[1] m1 = self.cape(qq, p1) m2 = self.cape(kk, p2) q = (qq * m1.cos()) + (self.rotate_every_two(qq) * m1.sin()) k = (kk * m2.cos()) + (self.rotate_every_two(kk) * m2.sin()) return q, k def attn_with_CaPE(self, f1, f2, p1, p2): """ Do attention dot production with CaPE pose encoding. 
# query = cape_embed(query, p_out_inv) # query f_q @ (p_out)^(-T) # key = cape_embed(key, p_in) # key f_k @ p_in :param f1: b (t1 l) d :param f2: b (t2 l) d :param p1: [b, t, 4] :param p2: [b, t, 4] :return: attention score: q@k.T """ l = f1.shape[1] // p1.shape[1] assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1] p1_reshape = einops.repeat(p1, 'b t m -> b (t l) m', l=l) # f1 [b, l*t1, d] p2_reshape = einops.repeat(p2, 'b t m -> b (t l) m', l=l) # f1 [b, l*t1, d] query, key = self.cape_embed(f1, f2, p1_reshape, p2_reshape) att = query @ key.permute(0, 2, 1) # [b, l*t1, l*t2] attention: query@key^T return att ################### 4DoF Verification ################################### def random_4dof_pose(B, T): pose = torch.zeros([B, T, 4]) pose[:, :, :3] = torch.rand([B, T, 3]) # radian angle # theta \in [0, pi], azimuth \in [0, 2pi], radius \in [0, pi], 0 pose[:, :, 1] *= (2*torch.pi) pose[:, :, 0] *= torch.pi pose[:, :, 2] *= torch.pi return pose.float() def look_at(origin, target, up): forward = (target - origin) forward = forward / torch.linalg.norm(forward, dim=-1, keepdim=True) right = torch.linalg.cross(forward, up) right = right / torch.linalg.norm(right, dim=-1, keepdim=True) new_up = torch.linalg.cross(forward, right) new_up = new_up / torch.linalg.norm(new_up, dim=-1, keepdim=True) rotation_matrix = torch.stack((right, new_up, forward, target), dim=-1) matrix = torch.cat([rotation_matrix, torch.tensor([[0, 0, 0, 1]]).repeat(rotation_matrix.shape[0],rotation_matrix.shape[1], 1, 1)], dim=-2) return matrix def pose_4dof2matrix(pose_4dof): """ :param pose_4dof: [b, t, 4] :return: pose 4x4 matrix: [b, t, 4, 4] """ theta = pose_4dof[:, :, 0] azimuth = pose_4dof[:, :, 1] radius = pose_4dof[:, :, 2] xyz = torch.stack([torch.sin(theta) * torch.cos(azimuth), torch.sin(theta) * torch.sin(azimuth), torch.cos(theta)], dim=-1) * radius.unsqueeze(-1) origin = torch.zeros_like(xyz) up = torch.zeros_like(xyz) up[:, :, 2] = 1 pose = look_at(origin, xyz, up) 
return pose def pose_matrix24dof(pose_matrix): """ :param pose_matrix: [b, t, 4, 4] :return: pose_4dof: [b, t, 4] theta, azimuth, radius, 0, looking at origin """ xyz = pose_matrix[..., :3, 3] xy = xyz[..., 0] ** 2 + xyz[..., 1] ** 2 radius = torch.sqrt(xy + xyz[..., 2] ** 2) theta = torch.arctan2(torch.sqrt(xy), xyz[..., 2]) # for elevation angle defined from Z-axis down azimuth = torch.arctan2(xyz[..., 1], xyz[..., 0]) pose = torch.stack([theta, azimuth, radius, torch.zeros_like(radius)], dim=-1) # move to [0, 2pi] pose %= (2 * torch.pi) return pose bs = 6 # batch size t1 = 3 # num of target views in each batch, can be arbitrary number t2 = 5 # num of reference views in each batch, can be arbitrary number l = 10 # len of token d = 16 # dim of token feature, need to mod 4 in this case # random init query and key f1 = torch.rand(bs, t1, l, d) # query f2 = torch.rand(bs, t2, l, d) # key f1 = einops.rearrange(f1, 'b t l d -> b (t l) d') f2 = einops.rearrange(f2, 'b t l d -> b (t l) d') #random init 4DoF pose [bs, t1, 4], theta, azimuth, radius, 0 p1 = random_4dof_pose(bs, t1) # [bs, t1, 4] p2 = random_4dof_pose(bs, t2) # [bs, t2, 4] p1_matrix = pose_4dof2matrix(p1) p1_4dof = pose_matrix24dof(p1_matrix) assert torch.allclose(p1, p1_4dof) p_delta_4dof = random_4dof_pose(bs, 1) # delta p is identical to p1 and p2 in each batch p1_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b (1 t) m', t=t1//1) p2_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b (1 t) m', t=t2//1) # run attention with CaPE 6DoF cape_4dof = CaPE_4DoF() # att att = cape_4dof.attn_with_CaPE(f1, f2, p1, p2) # att_delta att_delta = cape_4dof.attn_with_CaPE(f1, f2, p1+p1_delta_4dof, p2+p2_delta_4dof) # condition: att score should be the same i.e. 
non effect from any delta_p assert torch.allclose(att, att_delta, 1e-3) print("4DoF CaPE Verified") # print("You should get assertion error because 4DoF CaPE cannot handle 6DoF jitter") # # att_delta_6dof, it cannot handle 6dof jitter # p_delta_6dof = random_6dof_pose(bs, 1) # [bs, 1, 4, 4] any delta transformation in 6DoF # # delta p is identical to p1 and p2 in each batch # p1_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t1//1) # p2_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t2//1) # # 4dof pose to 4x4 matrix # p1_matrix = pose_4dof2matrix(p1) # p2_matrix = pose_4dof2matrix(p2) # att_delta_6dof = cape_4dof.attn_with_CaPE(f1, f2, pose_matrix24dof(p1_matrix@p1_delta_6dof), pose_matrix24dof(p2_matrix@p2_delta_6dof)) # # condition: att score should be the same i.e. non effect from any delta_p # assert torch.allclose(att, att_delta_6dof, 1e-3) ================================================ FILE: LICENSE ================================================ EscherNet SOFTWARE LICENCE AGREEMENT WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT. SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty free, non-transferable, non-sub- licensable licence (the “Licence”) to use the elastic fusion source code, including any modification, part or derivative (the “Software”). Ownership and Licence. Your rights to use and download the Software onto your computer, and all other copies that You are authorised to make, are specified in this Agreement. 
However, we (or our licensors) retain all rights, including but not limited to all copyright and other intellectual property rights anywhere in the world, in the Software not expressly granted to You in this Agreement. 2. Permitted use of the Licence: (a) You may download and install the Software onto one computer or server for use in accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is not accessible by other users unless they have themselves accepted the terms of this licence agreement. (b) You may use the Software solely for non-commercial, internal or academic research purposes and only in accordance with the terms of this Agreement. You may not use the Software for commercial purposes, including but not limited to (1) integration of all or part of the source code or the Software into a product for sale or licence by or on behalf of You to third parties or (2) use of the Software or any derivative of it for research to develop software products for sale or licence to a third party or (3) use of the Software or any derivative of it for research to develop non-software products for sale or licence to a third party, or (4) use of the Software to provide any service to an external organisation for which payment is received. Should You wish to use the Software for commercial purposes, You shall email researchcontracts.engineering@imperial.ac.uk . (c) Right to Copy. You may copy the Software for back-up and archival purposes, provided that each copy is kept in your possession and provided You reproduce our copyright notice (set out in Schedule 1) on each copy. (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may not transmit, transfer or sub-license this licence to use the Software or any of your rights or obligations under this Agreement to another party. (e) Identity of Licensee. The licence granted herein is personal to You. 
You shall not permit any third party to access, modify or otherwise use the Software nor shall You access modify or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for mutiple users or a site licence for the Software please contact us at researchcontracts.engineering@imperial.ac.uk . (f) Publications and presentations. You may make public, results or data obtained from, dependent on or arising from research carried out using the Software, provided that any such presentation or publication identifies the Software as the source of the results or the data, including the Copyright Notice given in each element of the Software, and stating that the Software has been made available for use by You under licence from Imperial College London and You provide a copy of any such publication to Imperial College London. 3. Prohibited Uses. You may not, without written permission from us at researchcontracts.engineering@imperial.ac.uk : (a) Use, copy, modify, merge, or transfer copies of the Software or any documentation provided by us which relates to the Software except as provided in this Agreement; (b) Use any back-up or archival copies of the Software (or allow anyone else to use such copies) for any purpose other than to replace the original copy in the event it is destroyed or becomes defective; or (c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the Software for any reason. 4. Warranty Disclaimer (a) Disclaimer. The Software has been developed for research purposes only. You acknowledge that we are providing the Software to You under this licence agreement free of charge and on condition that the disclaimer set out below shall apply. 
We do not represent or warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the suitability of the Software for any particular use or for use under any specific conditions; and (iii) whether use of the Software will infringe third-party rights. You acknowledge that You have reviewed and evaluated the Software to determine that it meets your needs and that You assume all responsibility and liability for determining the suitability of the Software as fit for your particular purposes and requirements. Subject to Clause 4(b), we exclude and expressly disclaim all express and implied representations, warranties, conditions and terms not stated herein (including the implied conditions or warranties of satisfactory quality, merchantable quality, merchantability and fitness for purpose). (b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or otherwise do not allow the exclusion of implied warranties, conditions or terms, in which case the above warranty disclaimer and exclusion will only apply to You to the extent permitted in the relevant jurisdiction and does not in any event exclude any implied warranties, conditions or terms which may not under applicable law be excluded. (c) Imperial College London disclaims all responsibility for the use which is made of the Software and any liability for the outcomes arising from using the Software. 5. Limitation of Liability (a) You acknowledge that we are providing the Software to You under this licence agreement free of charge and on condition that the limitation of liability set out below shall apply. 
Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort, negligence or otherwise, in respect of the Software and/or any related documentation provided to You by us including, but not limited to, liability for loss or corruption of data, loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect loss or damage of any kind arising out of or in connection with this licence agreement, however caused. This exclusion shall apply even if we have been advised of the possibility of such loss or damage. (b) You agree to indemnify Imperial College London and hold it harmless from and against any and all claims, damages and liabilities asserted by third parties (including claims for negligence) which arise directly or indirectly from the use of the Software or any derivative of it or the sale of any products based on the Software. You undertake to make no liability claim against any employee, student, agent or appointee of Imperial College London, in connection with this Licence or the Software. (c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory liability. (d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part, and, to that extent, they may not apply to you. Nothing in this licence agreement will affect your statutory rights or other relevant statutory provisions which cannot be excluded, restricted or modified, and its terms and conditions must be read and construed subject to any such statutory rights and/or provisions. 6. Confidentiality. You agree not to disclose any confidential information provided to You by us pursuant to this Agreement to any third party without our prior written consent. The obligations in this Clause 6 shall survive the termination of this Agreement for any reason. 7. Termination. 
(a) We may terminate this licence agreement and your right to use the Software at any time with immediate effect upon written notice to You. (b) This licence agreement and your right to use the Software automatically terminate if You: (i) fail to comply with any provisions of this Agreement; or (ii) destroy the copies of the Software in your possession, or voluntarily return the Software to us. (c) Upon termination You will destroy all copies of the Software. (d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years after first use of the Software under this licence agreement. 8. Miscellaneous Provisions. (a) This Agreement will be governed by and construed in accordance with the substantive laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes which may arise between us. (b) This is the entire agreement between us relating to the Software, and supersedes any prior purchase order, communications, advertising or representations concerning the Software. (c) No change or modification of this Agreement will be valid unless it is in writing, and is signed by us. (d) The unenforceability or invalidity of any part of this Agreement will not affect the enforceability or validity of the remaining parts. BSD Elements of the Software For BSD elements of the Software, the following terms shall apply: Copyright as indicated in the header of the individual element of the Software. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. 
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SCHEDULE 1 The Software EscherNet is a multiview diffusion model for scalable generative any-to-any novel view synthesis. It is based on the techniques described in the following publication: • Xin Kong, Shikun Liu, Xiaoyang Lyu, Marwan Taher, Xiaojuan Qi, Andrew J. Davison. EscherNet: A Generative Model for Scalable View Synthesis. ArXiv Preprint, 2024 _________________________ Acknowledgments If you use the software, you should reference the following paper in any publication: • Xin Kong, Shikun Liu, Xiaoyang Lyu, Marwan Taher, Xiaojuan Qi, Andrew J. Davison. EscherNet: A Generative Model for Scalable View Synthesis. ArXiv Preprint, 2024 ================================================ FILE: README.md ================================================ [comment]: <> (# EscherNet: A Generative Model for Scalable View Synthesis)

EscherNet: A Generative Model for Scalable View Synthesis

Xin Kong · Shikun Liu · Xiaoyang Lyu · Marwan Taher · Xiaojuan Qi · Andrew J. Davison

[comment]: <> (

PAPER

)

Paper | Project Page

Logo

EscherNet is a multi-view conditioned diffusion model for view synthesis. EscherNet learns implicit and generative 3D representations coupled with the camera positional encoding (CaPE), allowing precise and continuous relative control of the camera transformation between an arbitrary number of reference and target views.


## Install ``` conda env create -f environment.yml -n eschernet conda activate eschernet ``` ## Demo Run the demo to generate 25 randomly sampled novel views from (1,2,3,5,10) reference views: ```commandline bash eval_eschernet.sh ``` ## Camera Positional Encoding (CaPE) CaPE is applied in self/cross-attention for encoding camera pose info into transformers. The main modification is in `diffusers/models/attention_processor.py`. To quickly check the implementation of CaPE (6DoF and 4DoF), run: ``` python CaPE.py ``` ## Training ### Objaverse 1.0 Dataset Download Zero123's Objaverse Rendering data: ```commandline wget https://tri-ml-public.s3.amazonaws.com/datasets/views_release.tar.gz ``` Filter Zero-1-to-3 rendered views (empty images): ```commandline cd scripts python objaverse_filter.py --path /data/objaverse/views_release ``` ### Launch training Configure accelerator (8 A100 GPUs, bf16): ```commandline accelerate config ``` Choose 4DoF or 6DoF CaPE (Camera Positional Encoding): ```commandline cd 4DoF or 6DoF ``` Launch training: ```commandline accelerate launch train_eschernet.py --train_data_dir /data/objaverse/views_release --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --train_batch_size 256 --dataloader_num_workers 16 --mixed_precision bf16 --gradient_checkpointing --T_in 3 --T_out 3 --T_in_val 10 --output_dir logs_N3M3B256_SD1.5 --push_to_hub --hub_model_id ***** --hub_token hf_******************* --tracker_project_name eschernet ``` For monitoring training progress, we recommend [wandb](https://wandb.ai/site) for its simplicity and powerful features. ```commandline wandb login ``` Offline mode: ```commandline WANDB_MODE=offline python xxx.py ``` ## Evaluation We provide [raw results](https://huggingface.co/datasets/kxic/EscherNet-Results) and two checkpoints [4DoF](https://huggingface.co/kxic/eschernet-4dof) and [6DoF](https://huggingface.co/kxic/eschernet-6dof) for easier comparison. 
### Datasets ##### [GSO Google Scanned Objects](https://app.gazebosim.org/GoogleResearch/fuel/collections/Scanned%20Objects%20by%20Google%20Research) [GSO30](https://huggingface.co/datasets/kxic/EscherNet-Dataset/tree/main): We select 30 objects from the GSO dataset and render 25 randomly sampled novel views for each object for both NVS and 3D reconstruction evaluation. ##### [RTMV](https://huggingface.co/datasets/kxic/EscherNet-Dataset/tree/main) We use the 10 scenes from [google_scanned.tar](https://drive.google.com/drive/folders/1cUXxUp6g25WwzHnm_491zNJJ4T7R_fum) under the folder `40_scenes` for NVS evaluation. ##### [NeRF_Synthetic](https://huggingface.co/datasets/kxic/EscherNet-Dataset/tree/main) We use all 8 NeRF objects for 2D NVS evaluation. ##### [Franka16](https://huggingface.co/datasets/kxic/EscherNet-Dataset/tree/main) We collected 16 real-world object-centric recordings using a Franka Emika Panda robot arm with a RealSense D435i camera for real-world NVS evaluation. ##### [Text2Img](https://huggingface.co/datasets/kxic/EscherNet-Dataset/tree/main) We collected Text2Img generation results from the internet, [Stable Diffusion XL](https://github.com/Stability-AI/generative-models) (1 view) and [MVDream](https://github.com/bytedance/MVDream) (4 views: front, right, back, left) for NVS evaluation. 
### Novel View Synthesis (NVS) To get 2D Novel View Synthesis (NVS) results, set `cape_type, checkpoint, data_type, data_dir` and run: ```commandline bash ./eval_eschernet.sh ``` Evaluate 2D metrics (PSNR, SSIM, LPIPS): ```commandline cd metrics python eval_2D_NVS.py ``` ### 3D Reconstruction We first generate 36 novel views with `data_type=GSO3D` by: ```commandline bash ./eval_eschernet.sh ``` Then we adopt [NeuS](https://github.com/Totoro97/NeuS) for 3D reconstruction: ```commandline export CUDA_HOME=/usr/local/cuda-11.8 pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch cd 3drecon python run_NeuS.py ``` Evaluate 3D metrics (Chamfer Distance, IoU): ```commandline cd metrics python eval_3D_GSO.py ``` ## Gradio Demo TODO. To build locally: ```commandline python gradio_eschernet.py ``` ## Acknowledgement We have borrowed code extensively from the following repositories. Many thanks to the authors for sharing their code. - [Zero-1-to-3](https://github.com/cvlab-columbia/zero123) - [SyncDreamer](https://github.com/liuyuan-pal/SyncDreamer) - [MVDream](https://github.com/bytedance/MVDream) - [NeuS](https://github.com/Totoro97/NeuS) ## Citation If you find this work useful, a citation would be appreciated: ``` @article{kong2024eschernet, title={EscherNet: A Generative Model for Scalable View Synthesis}, author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J}, journal={arXiv preprint arXiv:2402.03908}, year={2024} } ``` ================================================ FILE: environment.yml ================================================ name: eschernet channels: - xformers - pytorch - nvidia - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu - blas=1.0=mkl - brotlipy=0.7.0=py39h27cfd23_1003 - bzip2=1.0.8=h7b6447c_0 - ca-certificates=2023.05.30=h06a4308_0 - certifi=2023.7.22=py39h06a4308_0 - cffi=1.15.1=py39h5eee18b_3 - charset-normalizer=2.0.4=pyhd3eb1b0_0 
- cryptography=41.0.2=py39h22a60cf_0 - cuda-cudart=11.8.89=0 - cuda-cupti=11.8.87=0 - cuda-libraries=11.8.0=0 - cuda-nvrtc=11.8.89=0 - cuda-nvtx=11.8.86=0 - cuda-runtime=11.8.0=0 - ffmpeg=4.3=hf484d3e_0 - filelock=3.9.0=py39h06a4308_0 - freetype=2.12.1=h4a9f257_0 - giflib=5.2.1=h5eee18b_3 - gmp=6.2.1=h295c915_3 - gmpy2=2.1.2=py39heeb90bb_0 - gnutls=3.6.15=he1e5248_0 - idna=3.4=py39h06a4308_0 - intel-openmp=2023.1.0=hdb19cb5_46305 - jinja2=3.1.2=py39h06a4308_0 - jpeg=9e=h5eee18b_1 - lame=3.100=h7b6447c_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.38=h1181459_1 - lerc=3.0=h295c915_0 - libcublas=11.11.3.6=0 - libcufft=10.9.0.58=0 - libcufile=1.7.0.149=0 - libcurand=10.3.3.53=0 - libcusolver=11.4.1.48=0 - libcusparse=11.7.5.86=0 - libdeflate=1.17=h5eee18b_0 - libffi=3.4.4=h6a678d5_0 - libgcc-ng=11.2.0=h1234567_1 - libgomp=11.2.0=h1234567_1 - libiconv=1.16=h7f8727e_2 - libidn2=2.3.4=h5eee18b_0 - libnpp=11.8.0.86=0 - libnvjpeg=11.9.0.86=0 - libpng=1.6.39=h5eee18b_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.0=h6a678d5_2 - libunistring=0.9.10=h27cfd23_0 - libwebp=1.2.4=h11a3e52_1 - libwebp-base=1.2.4=h5eee18b_1 - lz4-c=1.9.4=h6a678d5_0 - markupsafe=2.1.1=py39h7f8727e_0 - mkl=2023.1.0=h6d00ec8_46342 - mkl-service=2.4.0=py39h5eee18b_1 - mkl_fft=1.3.6=py39h417a72b_1 - mkl_random=1.2.2=py39h417a72b_1 - mpc=1.1.0=h10f8cd9_1 - mpfr=4.0.2=hb69a4c5_1 - mpmath=1.2.1=py39h06a4308_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - networkx=2.8.4=py39h06a4308_1 - numpy-base=1.25.0=py39hb5e798b_0 - openh264=2.1.1=h4ff587b_0 - openssl=3.0.9=h7f8727e_0 - pip=23.2.1=py39h06a4308_0 - pycparser=2.21=pyhd3eb1b0_0 - pyopenssl=23.2.0=py39h06a4308_0 - pysocks=1.7.1=py39h06a4308_0 - python=3.9.17=h955ad1f_0 - pytorch=2.0.0=py3.9_cuda11.8_cudnn8.7.0_0 - pytorch-cuda=11.8=h7e8668a_5 - pytorch-mutex=1.0=cuda - readline=8.2=h5eee18b_0 - requests=2.31.0=py39h06a4308_0 - setuptools=68.0.0=py39h06a4308_0 - sqlite=3.41.2=h5eee18b_0 - sympy=1.11.1=py39h06a4308_0 - 
tbb=2021.8.0=hdb19cb5_0 - tk=8.6.12=h1ccaba5_0 - torchaudio=2.0.0=py39_cu118 - torchtriton=2.0.0=py39 - torchvision=0.15.0=py39_cu118 - typing_extensions=4.7.1=py39h06a4308_0 - urllib3=1.26.16=py39h06a4308_0 - wheel=0.38.4=py39h06a4308_0 - xformers=0.0.19=py39_cu11.8.0_pyt2.0.0 - xz=5.4.2=h5eee18b_0 - zlib=1.2.13=h5eee18b_0 - zstd=1.5.5=hc292b87_0 - pip: - absl-py==1.4.0 - accelerate==0.23.0 - addict==2.4.0 - aiofiles==23.1.0 - aiohttp==3.8.5 - aiosignal==1.3.1 - albumentations==0.4.3 - altair==4.2.2 - annotated-types==0.5.0 - ansi2html==1.8.0 - antlr4-python3-runtime==4.9.3 - anyio==3.7.1 - appdirs==1.4.4 - asttokens==2.2.1 - async-timeout==4.0.2 - asyncer==0.0.2 - attrs==23.1.0 - backcall==0.2.0 - blinker==1.6.2 - braceexpand==0.1.7 - cachetools==5.3.1 - carvekit-colab==4.1.0 - click==8.1.6 - colorama==0.4.6 - coloredlogs==15.0.1 - comm==0.1.4 - configargparse==1.7 - contourpy==1.1.0 - controlnet-aux==0.0.7 - cycler==0.11.0 - dash==2.12.1 - dash-core-components==2.0.0 - dash-html-components==2.0.0 - dash-table==5.0.0 - dataclasses-json==0.6.1 - datasets==2.4.0 - dearpygui==1.10.1 - decorator==4.4.2 - deprecated==1.2.14 - diffusers==0.19.3 - dill==0.3.5.1 - docker-pycreds==0.4.0 - easydict==1.10 - einops==0.3.0 - entrypoints==0.4 - envlight==0.1.0 - exceptiongroup==1.1.2 - executing==2.0.0 - fastapi==0.100.1 - fastcore==1.5.29 - fastjsonschema==2.18.0 - ffmpy==0.3.1 - filetype==1.2.0 - fire==0.4.0 - flask==2.2.5 - flatbuffers==23.5.26 - fonttools==4.41.1 - freetype-py==2.4.0 - frozenlist==1.4.0 - fsspec==2023.6.0 - ftfy==6.1.1 - future==0.18.3 - gitdb==4.0.10 - gitpython==3.1.32 - glfw==2.6.2 - google-auth==2.22.0 - google-auth-oauthlib==1.0.0 - gradio==3.21.0 - grpcio==1.56.2 - h11==0.14.0 - h5py==3.10.0 - hjson==3.1.0 - httpcore==0.17.3 - httpx==0.24.1 - huggingface-hub==0.16.4 - humanfriendly==10.0 - imageio==2.32.0 - imageio-ffmpeg==0.4.8 - imgaug==0.2.6 - imgviz==1.7.3 - importlib-metadata==6.8.0 - importlib-resources==6.0.0 - iniconfig==2.0.0 - 
inquirerpy==0.3.4 - ipython==8.14.0 - ipywidgets==8.1.0 - itsdangerous==2.1.2 - jaxtyping==0.2.23 - jedi==0.18.2 - joblib==1.3.2 - jsonpatch==1.33 - jsonpointer==2.4 - jsonschema==4.18.4 - jsonschema-specifications==2023.7.1 - jupyter-core==5.3.1 - jupyterlab-widgets==3.0.8 - kiwisolver==1.4.4 - kornia==0.6.0 - lazy-loader==0.3 - libigl==2.4.1 - lightning-utilities==0.9.0 - linkify-it-py==2.0.2 - llvmlite==0.41.1 - loguru==0.7.0 - lovely-numpy==0.2.9 - lovely-tensors==0.1.15 - lpips==0.1.4 - markdown==3.4.4 - markdown-it-py==2.2.0 - marshmallow==3.20.1 - matplotlib==3.7.2 - matplotlib-inline==0.1.6 - mdit-py-plugins==0.3.3 - mdurl==0.1.2 - mesh2sdf==1.1.0 - modelcards==0.1.6 - moviepy==1.0.3 - multidict==6.0.4 - multiprocess==0.70.13 - mypy-extensions==1.0.0 - nbformat==5.5.0 - nerfacc==0.5.2 - nest-asyncio==1.5.7 - ninja==1.11.1 - numba==0.58.1 - numpy==1.24.4 - oauthlib==3.2.2 - objprint==0.2.3 - omegaconf==2.3.0 - onnxruntime==1.16.1 - onnxruntime-gpu==1.16.1 - open-clip-torch==2.7.0 - open3d==0.16.0 - opencv-python==4.5.5.64 - opencv-python-headless==4.8.0.74 - orjson==3.9.2 - packaging==23.1 - pandas==2.0.3 - parso==0.8.3 - pathtools==0.1.2 - pexpect==4.8.0 - pfzy==0.3.4 - pickleshare==0.7.5 - pillow==10.0.0 - piq==0.8.0 - platformdirs==3.10.0 - plotly==5.13.1 - pluggy==1.3.0 - plyfile==1.0.1 - pooch==1.7.0 - proglog==0.1.10 - prompt-toolkit==3.0.39 - protobuf==3.20.3 - psutil==5.9.5 - ptyprocess==0.7.0 - pudb==2019.2 - pure-eval==0.2.2 - py-cpuinfo==9.0.0 - pyarrow==12.0.1 - pyasn1==0.5.0 - pyasn1-modules==0.3.0 - pybind11==2.11.1 - pydantic==1.10.12 - pydantic-core==2.4.0 - pydeck==0.8.0 - pydeprecate==0.3.1 - pydub==0.25.1 - pyglet==1.5.0 - pygltflib==1.16.1 - pygments==2.15.1 - pyhocon==0.3.57 - pymatting==1.1.10 - pymcubes==0.1.2 - pymeshlab==2022.2.post4 - pympler==1.0.1 - pyopengl==3.1.0 - pyparsing==3.0.9 - pypose==0.6.5 - pyquaternion==0.9.9 - pyransac3d==0.6.0 - pyrender==0.1.45 - pytest==7.4.3 - python-dateutil==2.8.2 - python-multipart==0.0.6 - 
pytorch-fid==0.3.0 - pytorch-ignite==0.4.12 - pytorch-lightning==1.9.0 - pytz==2023.3 - pytz-deprecation-shim==0.1.0.post0 - pywavelets==1.4.1 - pyyaml==6.0.1 - referencing==0.30.0 - regex==2023.6.3 - rembg==2.0.50 - requests-oauthlib==1.3.1 - responses==0.18.0 - retrying==1.3.4 - rich==13.4.2 - rpds-py==0.9.2 - rsa==4.9 - safetensors==0.3.1 - scikit-image==0.20.0 - scikit-learn==1.3.0 - scipy==1.9.1 - segment-anything==1.0 - sentencepiece==0.1.99 - sentry-sdk==1.28.1 - setproctitle==1.3.2 - six==1.16.0 - smmap==5.0.0 - sniffio==1.3.0 - stack-data==0.6.2 - starlette==0.27.0 - streamlit==1.22.0 - style==1.1.0 - taming-transformers==0.0.1 - taming-transformers-rom1504==0.0.6 - tenacity==8.2.2 - tensorboard==2.13.0 - tensorboard-data-server==0.7.1 - termcolor==2.3.0 - test-tube==0.7.5 - threadpoolctl==3.2.0 - tifffile==2023.7.18 - timm==0.9.10 - tokenizers==0.13.3 - toml==0.10.2 - tomli==2.0.1 - toolz==0.12.0 - torch-efficient-distloss==0.1.3 - torch-fidelity==0.3.0 - torchmetrics==1.2.0 - tornado==6.3.2 - tqdm==4.65.0 - traitlets==5.9.0 - transformers==4.30.2 - trimesh==3.23.5 - typeguard==2.13.3 - typing-inspect==0.9.0 - tzdata==2023.3 - tzlocal==4.3.1 - uc-micro-py==1.0.2 - update==0.0.1 - urwid==2.1.2 - uvicorn==0.23.1 - validators==0.20.0 - varname==0.12.0 - visdom==0.2.4 - wandb==0.15.7 - watchdog==3.0.0 - wcwidth==0.2.6 - webdataset==0.2.5 - websocket-client==1.6.4 - websockets==11.0.3 - werkzeug==2.2.3 - widgetsnbextension==4.0.8 - wrapt==1.15.0 - xatlas==0.0.8 - xxhash==3.2.0 - yarl==1.9.2 - zipp==3.16.2 ================================================ FILE: eval_eschernet.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and import argparse import os import einops import numpy as np import torch import torch.utils.checkpoint from accelerate.utils import ProjectConfiguration, set_seed from PIL import Image from torchvision import transforms from tqdm.auto import tqdm import torchvision import json import cv2 from skimage.io import imsave import matplotlib.pyplot as plt # read .exr files for RTMV dataset os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a Zero123 training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default="lambdalabs/sd-image-variations-diffusers", required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help=( "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" " float32 precision." 
), ) parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") parser.add_argument( "--resolution", type=int, default=256, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--T_in", type=int, default=1, help="Number of input views" ) parser.add_argument( "--T_out", type=int, default=1, help="Number of output views" ) parser.add_argument( "--guidance_scale", type=float, default=3.0, help="unconditional guidance scale, if guidance_scale>1.0, do_classifier_free_guidance" ) parser.add_argument( "--data_dir", type=str, default=".", help=( "The input data dir. Should contain the .png files (or other data files) for the task." ), ) parser.add_argument( "--data_type", type=str, default="GSO25", help=( "The input data type. Chosen from GSO25, GSO3D, GSO100, RTMV, NeRF, Franka, MVDream, Text2Img" ), ) parser.add_argument( "--cape_type", type=str, default="6DoF", help=( "The camera pose encoding CaPE type. Chosen from 4DoF, 6DoF" ), ) parser.add_argument( "--output_dir", type=str, default="logs_eval", help=( "The output directory where the model predictions and checkpoints will be written." ), ) parser.add_argument( "--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--enable_xformers_memory_efficient_attention", default=True, help="Whether or not to use xformers." 
# create angles in archimedean spiral with T_out number
import math


def get_archimedean_spiral(sphere_radius, num_steps=250):
    """Sample camera positions along a spherical (Archimedean) spiral.

    See https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral"
    (c = a / pi).

    The spiral parameter ``i`` starts at 0.01 and is advanced by
    ``a / num_steps`` while it stays below ``a`` (``a`` is fixed to 40).

    Args:
        sphere_radius: Radius of the sphere the positions lie on.
        num_steps: Number of increments used to sweep ``i`` over ``[0, a)``.

    Returns:
        A tuple ``(translations, angles)``: ``translations`` is an array of
        ``(x, z, -y)`` positions (one row per sample) and ``angles`` is an
        array of ``[rad2deg(-i), rad2deg(theta)]`` rows, in degrees.
    """
    a = 40
    r = sphere_radius

    translations = []
    angles = []

    i = 0.01
    step = a / (1 * num_steps)
    while i < a:
        theta = i / a * math.pi
        x = r * math.sin(theta) * math.cos(-i)
        z = r * math.sin(-theta + math.pi) * math.sin(-i)
        y = r * - math.cos(theta)

        # Axes are permuted relative to the textbook (x, y, z) spiral.
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-i), np.rad2deg(theta)])

        i += step

    return np.array(translations), np.stack(angles)


def look_at(origin, target, up):
    """Build a 4x4 matrix from a viewing direction.

    The rotation columns are ``right``, ``new_up`` and ``-forward`` where
    ``forward`` is the unit vector from ``origin`` to ``target``. Note that
    the translation column is filled with ``target`` (not ``origin``).

    Args:
        origin: 3-vector the view direction starts from.
        target: 3-vector the view direction points to.
        up: 3-vector giving the approximate up direction.

    Returns:
        A 4x4 ``numpy`` array with last row ``[0, 0, 0, 1]``.
    """
    forward = (target - origin)
    forward = forward / np.linalg.norm(forward)
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)
    new_up = np.cross(forward, right)
    rotation_matrix = np.column_stack((right, new_up, -forward, target))
    # np.vstack is the non-deprecated spelling of np.row_stack.
    matrix = np.vstack((rotation_matrix, [0, 0, 0, 1]))
    return matrix


def _spherical_to_xyz(polar_rad, azimuth_rad, radius):
    """Convert polar/azimuth angles (radians) to a position at ``radius``."""
    return np.array([np.sin(polar_rad) * np.cos(azimuth_rad),
                     np.sin(-polar_rad + np.pi) * np.sin(azimuth_rad),
                     np.cos(polar_rad)]) * radius


def _camera_matrix_at(xyz, origin, up):
    """Return the inverted ``look_at`` matrix for ``xyz`` with its third row negated."""
    pose = look_at(origin, xyz, up)
    pose = np.linalg.inv(pose)
    pose[2, :] *= -1
    return pose


def main(args):
    """Run novel-view synthesis for every object in ``args.data_dir``.

    For each object the function gathers ``T_in`` input images and poses and
    ``T_out`` target poses (layout depends on ``args.data_type``), runs the
    Zero1to3 pipeline and writes the input strip, the predictions (and a GIF
    when ``T_out >= 100``) and, if available, the ground-truth strip under
    ``logs_{cape_type}/{data_type}/N{T_in}M{T_out}/<object>``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Raises:
        ValueError: If ``args.cape_type`` is neither ``4DoF`` nor ``6DoF``.
        NotImplementedError: If ``args.data_type`` is not supported.
    """
    if args.seed is not None:
        set_seed(args.seed)

    CaPE_TYPE = args.cape_type
    if CaPE_TYPE not in ("4DoF", "6DoF"):
        raise ValueError("CaPE_TYPE must be chosen from 4DoF, 6DoF")
    import sys
    # use the customized diffusers modules shipped in ./4DoF/ or ./6DoF/
    sys.path.insert(0, f"./{CaPE_TYPE}/")
    from diffusers import DDIMScheduler
    from dataset import get_pose
    from CN_encoder import CN_encoder
    from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline

    DATA_DIR = args.data_dir
    DATA_TYPE = args.data_type

    # data type -> (input render folder, output render folder, number of output views)
    data_specs = {
        "GSO25": ("render_mvs_25", "render_mvs_25", 25),            # for 2D metrics
        "GSO3D": ("render_mvs_25", "render_sync_36_single", 36),    # for 3D metrics
        "GSO100": ("render_mvs_25", "render_spiral_100", 100),      # for 360 gif
        "NeRF": (None, None, 200),
        "RTMV": (None, None, 20),
        "Franka": (None, None, 100),    # do a 360 gif
        "MVDream": (None, None, 100),   # do a 360 gif
        "Text2Img": (None, None, 100),  # do a 360 gif
    }
    if DATA_TYPE not in data_specs:
        raise NotImplementedError
    T_in_DATA_TYPE, T_out_DATA_TYPE, T_out = data_specs[DATA_TYPE]
    T_in = args.T_in
    OUTPUT_DIR = f"logs_{CaPE_TYPE}/{DATA_TYPE}/N{T_in}M{T_out}"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    if DATA_TYPE == "Text2Img":
        # objects are single rgba png files directly inside DATA_DIR
        obj_names = [f for f in os.listdir(DATA_DIR) if f.endswith('rgba.png')]
    else:
        # objects are sub-folders of DATA_DIR
        obj_names = [f for f in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, f))]

    weight_dtype = torch.float16
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    h, w = args.resolution, args.resolution
    bg_color = [1., 1., 1., 1.]
    radius = 1.5  # 1.8  # Objaverse training radius [1.5, 2.2]

    # Init Dataset
    image_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize((args.resolution, args.resolution)),  # 256, 256
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]
    )

    # Init pipeline
    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler",
                                              revision=args.revision)
    image_encoder = CN_encoder.from_pretrained(args.pretrained_model_name_or_path, subfolder="image_encoder",
                                               revision=args.revision)
    pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        revision=args.revision,
        scheduler=scheduler,
        image_encoder=None,
        safety_checker=None,
        feature_extractor=None,
        torch_dtype=weight_dtype,
    )
    pipeline.image_encoder = image_encoder
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=False)

    if args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()
    # enable vae slicing
    pipeline.enable_vae_slicing()

    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    homogeneous_row = np.array([[0, 0, 0, 1]])

    for obj_name in tqdm(obj_names):
        print(f"Processing {obj_name}")
        if DATA_TYPE == "NeRF":
            # Skip objects that were already rendered. Results live under
            # OUTPUT_DIR, so that is where the finished GIF must be looked up.
            if os.path.exists(os.path.join(OUTPUT_DIR, obj_name, "output.gif")):
                continue
            # load train info
            with open(os.path.join(DATA_DIR, obj_name, "transforms_train.json"), "r") as f:
                train_info = json.load(f)["frames"]
            # load test info
            with open(os.path.join(DATA_DIR, obj_name, "transforms_test.json"), "r") as f:
                test_info = json.load(f)["frames"]

            # largest camera distance among the training views; poses are later
            # rescaled by radius / max_t
            max_t = 0
            for frame in train_info:
                pose = np.array(frame["transform_matrix"]).reshape(4, 4)
                radii = np.linalg.norm(pose[:3, -1])
                if max_t < radii:
                    max_t = radii

            info_dir = os.path.join("metrics/NeRF_idx", obj_name)
            assert os.path.exists(info_dir)
            # use fixed train index
            train_index = np.load(os.path.join(info_dir, f"train_N{T_in}M20_random.npy"))
            test_index = np.arange(len(test_info))  # use all test views
        elif DATA_TYPE == "Franka":
            angles_in = np.load(os.path.join(DATA_DIR, obj_name, "angles.npy"))  # azimuth, elevation in radians
            assert T_in <= len(angles_in)
            total_index = np.arange(0, len(angles_in))  # num of input views
            # random shuffle total_index
            np.random.shuffle(total_index)
            train_index = total_index[:T_in]
            xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            test_index = np.arange(len(angles_out))  # use all 100 test views
        elif DATA_TYPE == "MVDream":
            # 4 input views front right back left
            angles_in = []
            for polar in [90]:  # 1
                for azimu in np.arange(0, 360, 90):  # 4
                    angles_in.append(np.array([azimu, polar]))
            assert T_in == len(angles_in)
            xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)
        elif DATA_TYPE == "Text2Img":
            # 1 input view
            angles_in = [np.array([0, 90])]
            assert T_in == len(angles_in)
            xyzs, angles_out = get_archimedean_spiral(radius, T_out)
            origin = np.array([0, 0, 0])
            up = np.array([0, 0, 1])
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)
        else:
            train_index = np.arange(T_in)
            test_index = np.arange(T_out)

        # prepare input img + pose, output pose
        input_image = []
        pose_in = []
        pose_out = []
        gt_image = []
        for T_in_index in train_index:
            if DATA_TYPE == "RTMV":
                img_path = os.path.join(DATA_DIR, obj_name, '%05d.exr' % T_in_index)
                input_im = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
                img = cv2.cvtColor(input_im, cv2.COLOR_BGR2RGB, input_im)
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                input_image.append(image_transforms(img))
                # load input pose
                pose_path = os.path.join(DATA_DIR, obj_name, '%05d.json' % T_in_index)
                with open(pose_path, "r") as f:
                    pose_dict = json.load(f)
                input_RT = np.array(pose_dict["camera_data"]["cam2world"]).T
                input_RT = np.linalg.inv(input_RT)[:3]
                pose_in.append(get_pose(np.concatenate([input_RT[:3, :], homogeneous_row], axis=0)))
            else:
                if DATA_TYPE == "NeRF":
                    img_path = os.path.join(DATA_DIR, obj_name, train_info[T_in_index]["file_path"] + ".png")
                    pose = np.array(train_info[T_in_index]["transform_matrix"])
                    if CaPE_TYPE == "6DoF":
                        # blender to opencv
                        pose[1:3, :] *= -1
                        pose = np.linalg.inv(pose)
                        # scale radius to [1.5, 2.2]
                        pose[:3, 3] *= 1. / max_t * radius
                    elif CaPE_TYPE == "4DoF":
                        pose = np.linalg.inv(pose)
                    pose_in.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "Franka":
                    img_path = os.path.join(DATA_DIR, obj_name, "images_rgba", f"frame{T_in_index:06d}.png")
                    azimuth, elevation = np.rad2deg(angles_in[T_in_index])
                    print("input angles index", T_in_index, "azimuth", azimuth, "elevation", elevation)
                    if CaPE_TYPE == "4DoF":
                        # torch.from_numpy rejects Python lists; build the tensor directly
                        pose_in.append(torch.tensor([np.deg2rad(90. - elevation), np.deg2rad(azimuth - 180), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        neg_i = np.deg2rad(azimuth - 180)
                        neg_theta = np.deg2rad(90. - elevation)
                        xyz = _spherical_to_xyz(neg_theta, neg_i, radius)
                        pose_in.append(torch.from_numpy(get_pose(_camera_matrix_at(xyz, origin, up))))
                elif DATA_TYPE == "MVDream" or DATA_TYPE == "Text2Img":
                    if DATA_TYPE == "MVDream":
                        img_path = os.path.join(DATA_DIR, obj_name, f"{T_in_index}_rgba.png")
                    elif DATA_TYPE == "Text2Img":
                        img_path = os.path.join(DATA_DIR, obj_name)
                    azimuth, polar = angles_in[T_in_index]
                    if CaPE_TYPE == "4DoF":
                        pose_in.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        neg_theta = np.deg2rad(polar)
                        neg_i = np.deg2rad(azimuth)
                        xyz = _spherical_to_xyz(neg_theta, neg_i, radius)
                        pose_in.append(torch.from_numpy(get_pose(_camera_matrix_at(xyz, origin, up))))
                else:  # GSO
                    img_path = os.path.join(DATA_DIR, obj_name, T_in_DATA_TYPE, "model/%03d.png" % T_in_index)
                    pose_path = os.path.join(DATA_DIR, obj_name, T_in_DATA_TYPE, "model/%03d.npy" % T_in_index)
                    pose_in.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], homogeneous_row], axis=0)))

                # load image, painting fully transparent pixels with the background colour
                img = plt.imread(img_path)
                img[img[:, :, -1] == 0.] = bg_color
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                input_image.append(image_transforms(img))

        for T_out_index in test_index:
            if DATA_TYPE == "RTMV":
                img_path = os.path.join(DATA_DIR, obj_name, '%05d.exr' % T_out_index)
                gt_im = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
                img = cv2.cvtColor(gt_im, cv2.COLOR_BGR2RGB, gt_im)
                img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                gt_image.append(image_transforms(img))
                # load pose
                pose_path = os.path.join(DATA_DIR, obj_name, '%05d.json' % T_out_index)
                with open(pose_path, "r") as f:
                    pose_dict = json.load(f)
                output_RT = np.array(pose_dict["camera_data"]["cam2world"]).T
                output_RT = np.linalg.inv(output_RT)[:3]
                pose_out.append(get_pose(np.concatenate([output_RT[:3, :], homogeneous_row], axis=0)))
            else:
                if DATA_TYPE == "NeRF":
                    img_path = os.path.join(DATA_DIR, obj_name, test_info[T_out_index]["file_path"] + ".png")
                    pose = np.array(test_info[T_out_index]["transform_matrix"])
                    if CaPE_TYPE == "6DoF":
                        # blender to opencv
                        pose[1:3, :] *= -1
                        pose = np.linalg.inv(pose)
                        # scale radius to [1.5, 2.2]
                        pose[:3, 3] *= 1. / max_t * radius
                    elif CaPE_TYPE == "4DoF":
                        pose = np.linalg.inv(pose)
                    pose_out.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "Franka":
                    img_path = None
                    azimuth, polar = angles_out[T_out_index]
                    if CaPE_TYPE == "4DoF":
                        # torch.from_numpy rejects Python lists; build the tensor directly
                        pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        # sanity check: the spiral position must match the one
                        # recomputed from its own angles
                        xyz = _spherical_to_xyz(np.deg2rad(polar), np.deg2rad(azimuth), radius)
                        assert np.allclose(xyzs[T_out_index], xyz)
                        pose = _camera_matrix_at(xyzs[T_out_index], origin, up)
                        pose_out.append(torch.from_numpy(get_pose(pose)))
                elif DATA_TYPE == "MVDream" or DATA_TYPE == "Text2Img":
                    img_path = None
                    azimuth, polar = angles_out[T_out_index]
                    if CaPE_TYPE == "4DoF":
                        pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
                    elif CaPE_TYPE == "6DoF":
                        pose = _camera_matrix_at(xyzs[T_out_index], origin, up)
                        pose_out.append(torch.from_numpy(get_pose(pose)))
                else:  # GSO
                    img_path = os.path.join(DATA_DIR, obj_name, T_out_DATA_TYPE, "model/%03d.png" % T_out_index)
                    pose_path = os.path.join(DATA_DIR, obj_name, T_out_DATA_TYPE, "model/%03d.npy" % T_out_index)
                    if T_out_DATA_TYPE == "render_mvs_25":  # blender coordinate
                        pose_out.append(get_pose(np.concatenate([np.load(pose_path)[:3, :], homogeneous_row], axis=0)))
                    else:  # opencv coordinate
                        pose = get_pose(np.concatenate([np.load(pose_path)[:3, :], homogeneous_row], axis=0))
                        pose[1:3, :] *= -1  # pose out 36 is in opencv coordinate, pose in 25 is in blender coordinate
                        pose_out.append(torch.from_numpy(pose))

                # load image
                if img_path is not None:  # sometimes don't have GT target view image
                    img = plt.imread(img_path)
                    img[img[:, :, -1] == 0.] = bg_color
                    img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                    gt_image.append(image_transforms(img))

        # [B, T, C, H, W]
        input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
        if len(gt_image) > 0:
            gt_image = torch.stack(gt_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
        # stack the per-view poses and add a leading batch dimension below
        pose_in = np.stack(pose_in)
        pose_out = np.stack(pose_out)

        if CaPE_TYPE == "6DoF":
            pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
            pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
            pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
            pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)

        pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
        pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)

        input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
        if len(gt_image) > 0:
            gt_image = einops.rearrange(gt_image, "b t c h w -> (b t) c h w")
        assert T_in == input_image.shape[0]
        assert T_in == pose_in.shape[1]
        assert T_out == pose_out.shape[1]

        # run inference; the 6DoF pipeline additionally takes the inverse poses
        if CaPE_TYPE == "6DoF":
            poses = [[pose_out, pose_out_inv], [pose_in, pose_in_inv]]
        else:  # 4DoF
            poses = [pose_out, pose_in]
        with torch.autocast("cuda"):
            image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
                             poses=poses,
                             height=h, width=w, T_in=T_in, T_out=T_out,
                             guidance_scale=args.guidance_scale, num_inference_steps=50, generator=generator,
                             output_type="numpy").images

        # save results
        output_dir = os.path.join(OUTPUT_DIR, obj_name)
        os.makedirs(output_dir, exist_ok=True)
        # save input image for visualization
        imsave(os.path.join(output_dir, 'input.png'),
               ((np.concatenate(input_image.permute(0, 2, 3, 1).cpu().numpy(), 1) + 1) / 2 * 255).astype(np.uint8))
        # save output image
        if T_out >= 100:
            # save to N imgs
            for i in range(T_out):
                imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
            # make a gif
            frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
            frame_one = frames[0]
            # append only the remaining frames, otherwise frame 0 is stored twice
            frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames[1:],
                           save_all=True, duration=50, loop=1)
        else:
            imsave(os.path.join(output_dir, '0.png'), (np.concatenate(image, 1) * 255).astype(np.uint8))
        # save gt for visualization
        if len(gt_image) > 0:
            imsave(os.path.join(output_dir, 'gt.png'),
                   ((np.concatenate(gt_image.permute(0, 2, 3, 1).cpu().numpy(), 1) + 1) / 2 * 255).astype(np.uint8))


if __name__ == "__main__":
    args = parse_args()
    main(args)
#data_dir="/home/xin/data/EscherNet/Data/Text2Img/" ################### Chose data type ########################## # run for T_in in "${T_ins[@]}"; do python eval_eschernet.py --pretrained_model_name_or_path "$pretrained_model" \ --data_dir "$data_dir" \ --data_type "$data_type" \ --cape_type "$cape_type" \ --T_in "$T_in" done ================================================ FILE: metrics/RenderOption_phong.json ================================================ { "background_color" : [ 1.0, 1.0, 1.0 ], "class_name" : "RenderOption", "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], "image_max_depth" : 3000, "image_stretch_option" : 1, "interpolation_option" : 0, "light0_color" : [ 1.0, 1.0, 1.0 ], "light0_diffuse_power" : 0.66000000000000003, "light0_position" : [ 0.0, 0.0, 2.0 ], "light0_specular_power" : 0.20000000000000001, "light0_specular_shininess" : 100.0, "light1_color" : [ 1.0, 1.0, 1.0 ], "light1_diffuse_power" : 0.66000000000000003, "light1_position" : [ 0.0, 0.0, 2.0 ], "light1_specular_power" : 0.20000000000000001, "light1_specular_shininess" : 100.0, "light2_color" : [ 1.0, 1.0, 1.0 ], "light2_diffuse_power" : 0.66000000000000003, "light2_position" : [ 0.0, 0.0, -2.0 ], "light2_specular_power" : 0.20000000000000001, "light2_specular_shininess" : 100.0, "light3_color" : [ 1.0, 1.0, 1.0 ], "light3_diffuse_power" : 0.66000000000000003, "light3_position" : [ 0.0, 0.0, -2.0 ], "light3_specular_power" : 0.20000000000000001, "light3_specular_shininess" : 100.0, "light_ambient_color" : [ 0.0, 0.0, 0.0 ], "light_on" : true, "line_width" : 1.0, "mesh_color_option" : 0, "mesh_shade_option" : 0, "mesh_show_back_face" : false, "mesh_show_wireframe" : false, "point_color_option" : 0, "point_show_normal" : false, "point_size" : 5.0, "show_coordinate_frame" : false, "version_major" : 1, "version_minor" : 0 } ================================================ FILE: metrics/RenderOption_rgb.json 
================================================ { "background_color" : [ 1.0, 1.0, 1.0 ], "class_name" : "RenderOption", "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], "image_max_depth" : 3000, "image_stretch_option" : 1, "interpolation_option" : 0, "light0_color" : [ 1.0, 1.0, 1.0 ], "light0_diffuse_power" : 0.66000000000000003, "light0_position" : [ 0.0, 0.0, 2.0 ], "light0_specular_power" : 0.20000000000000001, "light0_specular_shininess" : 100.0, "light1_color" : [ 1.0, 1.0, 1.0 ], "light1_diffuse_power" : 0.66000000000000003, "light1_position" : [ 0.0, 0.0, 2.0 ], "light1_specular_power" : 0.20000000000000001, "light1_specular_shininess" : 100.0, "light2_color" : [ 1.0, 1.0, 1.0 ], "light2_diffuse_power" : 0.66000000000000003, "light2_position" : [ 0.0, 0.0, -2.0 ], "light2_specular_power" : 0.20000000000000001, "light2_specular_shininess" : 100.0, "light3_color" : [ 1.0, 1.0, 1.0 ], "light3_diffuse_power" : 0.66000000000000003, "light3_position" : [ 0.0, 0.0, -2.0 ], "light3_specular_power" : 0.20000000000000001, "light3_specular_shininess" : 100.0, "light_ambient_color" : [ 0.0, 0.0, 0.0 ], "light_on" : false, "line_width" : 1.0, "mesh_color_option" : 1, "mesh_shade_option" : 0, "mesh_show_back_face" : false, "mesh_show_wireframe" : false, "point_color_option" : 0, "point_show_normal" : false, "point_size" : 5.0, "show_coordinate_frame" : false, "version_major" : 1, "version_minor" : 0 } ================================================ FILE: metrics/ScreenCamera_0.json ================================================ { "class_name" : "PinholeCameraParameters", "extrinsic" : [ 0.65779589839210661, 0.29716996766085974, -0.69209433343942717, 0.0, 0.742984256854886, -0.40680953545890097, 0.53148884835479482, 0.0, -0.12360805040252906, -0.86382637849289878, -0.48839025143375769, 0.0, 1.3877787807814457e-17, 0.0, 1.2124355652982142, 1.0 ], "intrinsic" : { "height" : 512, "intrinsic_matrix" : [ 443.40500673763262, 0.0, 
0.0, 0.0, 443.40500673763262, 0.0, 255.5, 255.5, 1.0 ], "width" : 512 }, "version_major" : 1, "version_minor" : 0 } ================================================ FILE: metrics/ScreenCamera_1.json ================================================ { "class_name" : "PinholeCameraParameters", "extrinsic" : [ 0.4578931378650884, -0.46762792585253216, 0.75608068171255505, 0.0, -0.88729477065783835, -0.29315579979225609, 0.35604447336170914, 0.0, 0.055153098365525073, -0.83389675620607939, -0.54915784227274889, 0.0, 8.3266726846886741e-17, -1.1102230246251565e-16, 1.2124355652982142, 1.0 ], "intrinsic" : { "height" : 512, "intrinsic_matrix" : [ 443.40500673763262, 0.0, 0.0, 0.0, 443.40500673763262, 0.0, 255.5, 255.5, 1.0 ], "width" : 512 }, "version_major" : 1, "version_minor" : 0 } ================================================ FILE: metrics/ScreenCamera_2.json ================================================ { "class_name" : "PinholeCameraParameters", "extrinsic" : [ 0.38584211693687692, -0.53027159447263705, -0.75494231361946895, 0.0, 0.91280788326777862, 0.10073197705559719, 0.39577119279031425, 0.0, -0.13381938963344159, -0.84182249013311461, 0.52290273097796625, 0.0, -5.5511151231257827e-17, 0.0, 1.2124355652982139, 1.0 ], "intrinsic" : { "height" : 512, "intrinsic_matrix" : [ 443.40500673763262, 0.0, 0.0, 0.0, 443.40500673763262, 0.0, 255.5, 255.5, 1.0 ], "width" : 512 }, "version_major" : 1, "version_minor" : 0 } ================================================ FILE: metrics/ScreenCamera_3.json ================================================ { "class_name" : "PinholeCameraParameters", "extrinsic" : [ 0.38393168649424414, 0.52506956671681526, 0.7595382875230855, 0.0, -0.91049736873254372, 0.35211050739581706, 0.21682419632636629, 0.0, -0.15359362498749676, -0.77480329160973249, 0.61326083983401714, 0.0, 1.3877787807814457e-17, 1.1102230246251565e-16, 1.2124355652982139, 1.0 ], "intrinsic" : { "height" : 512, "intrinsic_matrix" : [ 443.40500673763262, 0.0, 0.0, 
LPIPS = lpips.LPIPS(net='alex', version='0.1')


def calc_2D_metrics(pred_np, gt_np):
    """Compute image-space metrics between one prediction and its ground truth.

    Args:
        pred_np: Predicted image, [H, W, 3], values in [0, 255], np.uint8.
        gt_np: Ground-truth image with the same layout as ``pred_np``.

    Returns:
        Tuple ``(loss, lpips, ssim, psnr)`` where ``loss`` is the MSE and
        ``lpips`` the LPIPS distance, both computed on images rescaled to
        [-1, 1]; ``ssim`` and ``psnr`` are computed on the uint8 inputs.
    """
    pred_image = torch.from_numpy(pred_np).unsqueeze(0).permute(0, 3, 1, 2)
    gt_image = torch.from_numpy(gt_np).unsqueeze(0).permute(0, 3, 1, 2)
    # [0-255] -> [-1, 1]
    pred_image = pred_image.float() / 127.5 - 1
    gt_image = gt_image.float() / 127.5 - 1
    # Pure evaluation: no autograd graph is needed for the network forward.
    with torch.no_grad():
        # pixel loss
        loss = F.mse_loss(pred_image[0], gt_image[0]).item()
        # LPIPS on [-1, 1] torch tensors
        lpips_value = LPIPS(pred_image[0], gt_image[0]).item()
    # SSIM
    ssim = calculate_ssim(pred_np, gt_np, channel_axis=2)
    # PSNR
    psnr = cv2.PSNR(gt_np, pred_np)

    return loss, lpips_value, ssim, psnr


# script to evaluate the model on GSO/RTMV/NeRF dataset
# todo modify the path
LOG_DIR = "../logs_4DoF"  # 6DoF, 4DoF
DATASET = "NeRF"  # GSO RTMV NeRF

if DATASET == "GSO":
    DATA_DIR = "/home/xin/data/EscherNet/Data/GSO30/"
    T_ins = [1, 2, 3, 5, 10]
    total_views = 25
    DATA_TYPE = "render_mvs_25"
    start_id = 10  # from 10 to total_views as test views
elif DATASET == "RTMV":
    # DATA_DIR is set per T_in below: the ground truth strip is read from the log folder
    T_ins = [1, 2, 3, 5, 10]
    total_views = 20
    start_id = 10
elif DATASET == "NeRF":
    DATA_DIR = "/home/xin/data/EscherNet/Data/nerf_synthetic/"
    T_ins = [1, 2, 3, 5, 10, 20, 50, 100]
    total_views = 200
    start_id = 0
else:
    raise NotImplementedError

for T_in in tqdm(T_ins):
    if DATASET == "GSO":
        log_dir = os.path.join(LOG_DIR, f"GSO25/N{T_in}M25")
    elif DATASET == "RTMV":
        log_dir = os.path.join(LOG_DIR, f"RTMV/N{T_in}M20")
        DATA_DIR = log_dir
    elif DATASET == "NeRF":
        log_dir = os.path.join(LOG_DIR, f"NeRF/N{T_in}M200")

    # running sums over all objects for this T_in
    val_lpips = 0
    val_ssim = 0
    val_psnr = 0
    val_loss = 0
    val_num = 0

    # get all objects
    objects = [f for f in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, f))]

    for obj in objects:
        # Per-object sums. Named obj_* so the imported `lpips` module is not
        # overwritten at module scope.
        obj_lpips = 0
        obj_ssim = 0
        obj_psnr = 0
        obj_loss = 0
        num = 0
        color = [1., 1., 1., 1.]
        if DATASET == "NeRF":
            # load test info
            with open(os.path.join(DATA_DIR, obj, "transforms_test.json"), "r") as f:
                test_info = json.load(f)["frames"]
            total_views = len(test_info)

        for i in range(start_id, total_views):
            # load the ground truth
            if DATASET == "GSO":
                gt_path = os.path.join(DATA_DIR, obj, DATA_TYPE, "model", f"{i:03d}.png")  # 001 for target view
                img = plt.imread(gt_path)
                img[img[:, :, -1] == 0.] = color
                gt = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
                gt = gt.resize((256, 256))
                gt = np.array(gt)
            elif DATASET == "RTMV":
                # gt.png is one horizontal strip; cut out the i-th 256-pixel-wide tile
                gt_path = os.path.join(DATA_DIR, obj, "gt.png")
                img = plt.imread(gt_path)
                gt = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
                gt = np.array(gt)
                gt = gt[:, 256 * i:256 * (i + 1), :]
            elif DATASET == "NeRF":
                img_path = os.path.join(DATA_DIR, obj, test_info[i]["file_path"] + ".png")
                img = plt.imread(img_path)
                img[img[:, :, -1] == 0.] = color
                gt = Image.fromarray(np.uint8(img[:, :, :3] * 255.)).convert("RGB")
                gt = gt.resize((256, 256))
                gt = np.array(gt)

            # load the prediction
            if DATASET in ("GSO", "RTMV"):
                # 0.png is one horizontal strip of predictions; cut out the
                # i-th 256-pixel-wide tile
                pred_path = os.path.join(log_dir, obj, "0.png")
                pred = plt.imread(pred_path)
                pred = Image.fromarray(np.uint8(pred[:, :, :3] * 255.))
                pred = np.array(pred)
                pred = pred[:, 256 * i:256 * (i + 1), :]
            elif DATASET == "NeRF":
                # read preds one by one
                pred_path = os.path.join(log_dir, obj, f"{i}.png")
                pred = plt.imread(pred_path)
                pred = Image.fromarray(np.uint8(pred[:, :, :3] * 255.))
                pred = pred.resize((256, 256))
                pred = np.array(pred)

            # 2D metrics
            loss_i, lpips_i, ssim_i, psnr_i = calc_2D_metrics(pred, gt)
            obj_loss += loss_i
            obj_lpips += lpips_i
            obj_ssim += ssim_i
            obj_psnr += psnr_i
            num += 1

        if num == 0:
            # nothing to average; avoid a ZeroDivisionError and keep going
            print(f"obj: {obj}, no test views found, skipped")
            continue

        obj_loss /= num
        obj_lpips /= num
        obj_ssim /= num
        obj_psnr /= num
        val_loss += obj_loss
        val_lpips += obj_lpips
        val_ssim += obj_ssim
        val_psnr += obj_psnr
        val_num += 1
        obj_summary = f"obj: {obj}, loss: {obj_loss}, lpips: {obj_lpips}, ssim: {obj_ssim}, psnr: {obj_psnr}"
        print(obj_summary)
        # save the results to txt
        with open(os.path.join(log_dir, "metrics2D.txt"), "a") as f:
            f.write(obj_summary + "\n")

    if val_num == 0:
        # no object produced metrics for this T_in; nothing to average
        print(f"T_in: {T_in}, no objects evaluated")
        continue

    avg_summary = (f"avg loss: {val_loss / val_num}, avg lpips: {val_lpips / val_num}, "
                   f"avg ssim: {val_ssim / val_num}, avg psnr: {val_psnr / val_num}")
    print(avg_summary)
    # save the results to txt
    with open(os.path.join(log_dir, "metrics2D.txt"), "a") as f:
        f.write(avg_summary)
        # add a new line
        f.write("\n")
# seed
np.random.seed(0)


def cartesian_to_spherical(xyz):
    """Convert Cartesian points to (theta, azimuth, radius).

    Args:
        xyz: Array of shape [N, 3].

    Returns:
        Array of shape [3, N] stacking ``theta`` (angle measured from the
        +Z axis down), ``azimuth`` (``arctan2(y, x)``) and the Euclidean
        norm of each point. Angles are in radians.
    """
    xy = xyz[:, 0] ** 2 + xyz[:, 1] ** 2
    z = np.sqrt(xy + xyz[:, 2] ** 2)
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # for elevation angle defined from Z-axis down
    # np.arctan2(xyz[:, 2], np.sqrt(xy)) would give the elevation from the XY-plane up
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, z])


def get_pose(target_RT):
    """Return the spherical coordinates of ``-R.T @ T`` for a 3x4 ``[R | T]``.

    Args:
        target_RT: Array whose first three columns are ``R`` and whose last
            column is ``T`` (callers in this file pass a 3x4 matrix).

    Returns:
        Tuple ``(theta, azimuth, radius)``, each an array of length 1, as
        produced by ``cartesian_to_spherical``.
    """
    R, T = target_RT[:3, :3], target_RT[:, -1]
    T_target = -R.T @ T
    theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
    return theta_target, azimuth_target, z_target


def trimesh_to_open3d(src):
    """Copy vertices, faces and RGB vertex colours of ``src`` into an Open3D mesh.

    Args:
        src: Mesh exposing ``vertices``, ``faces`` and
            ``visual.vertex_colors`` (presumably a ``trimesh.Trimesh``).

    Returns:
        An ``o3d.geometry.TriangleMesh`` with vertex normals computed.
    """
    dst = o3d.geometry.TriangleMesh()
    dst.vertices = o3d.utility.Vector3dVector(src.vertices)
    dst.triangles = o3d.utility.Vector3iVector(src.faces)
    # drop the alpha channel and rescale 0..255 colours to 0..1
    vertex_colors = src.visual.vertex_colors[:, :3].astype(np.float32) / 255.0
    dst.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
    dst.compute_vertex_normals()

    return dst


def normalize_mesh(vertices):
    """Scale vertices so the longest bounding-box side is 1, then centre them.

    Args:
        vertices: Array of shape [N, 3].

    Returns:
        A new array of the same shape whose axis-aligned bounding box is
        centred at the origin.
    """
    max_pt = np.max(vertices, 0)
    min_pt = np.min(vertices, 0)
    scale = 1 / np.max(max_pt - min_pt)
    vertices = vertices * scale

    # recentre using the bounding box of the scaled vertices
    max_pt = np.max(vertices, 0)
    min_pt = np.min(vertices, 0)
    center = (max_pt + min_pt) / 2
    vertices = vertices - center[None, :]
    return vertices


def capture_screenshots(mesh_rec_o3d, cam_param, render_param, img_name):
    """Render ``mesh_rec_o3d`` off-screen at 512x512 and save it to ``img_name``.

    Args:
        mesh_rec_o3d: Open3D geometry to render.
        cam_param: Path to a pinhole camera parameter JSON file.
        render_param: Path to a render option JSON file.
        img_name: Output image path.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=512, height=512, visible=False)
    try:
        vis.add_geometry(mesh_rec_o3d)
        ctr = vis.get_view_control()
        parameters = o3d.io.read_pinhole_camera_parameters(cam_param)
        ctr.convert_from_pinhole_camera_parameters(parameters, allow_arbitrary=True)
        vis.get_render_option().load_from_json(render_param)  # rgb
        vis.poll_events()
        vis.update_renderer()
        vis.capture_screen_image(img_name, do_render=True)
        del ctr
    finally:
        # always release the window, even if loading or rendering fails
        vis.destroy_window()
        del vis
o3d.io.read_triangle_mesh(gt_meshfile, True) # trimesh load point cloud if "PointE" in REC_DIR: pc_rec = trimesh.load(rec_pcfile) if method_name == "GT": mesh_rec = trimesh.load(gt_meshfile) mesh_rec_o3d = o3d.io.read_triangle_mesh(gt_meshfile, True) else: mesh_rec = trimesh.load(rec_meshfile) mesh_rec_o3d = o3d.io.read_triangle_mesh(rec_meshfile, True) # normalize mesh_gt.vertices = normalize_mesh(mesh_gt.vertices) vertices_gt = np.asarray(mesh_gt_o3d.vertices) vertices_gt = normalize_mesh(vertices_gt) mesh_gt_o3d.vertices = o3d.utility.Vector3dVector(vertices_gt) if "PointE" in REC_DIR: pc_rec.vertices = normalize_mesh(pc_rec.vertices) # normalize mesh_rec.vertices = normalize_mesh(mesh_rec.vertices) vertices_rec = np.asarray(mesh_rec_o3d.vertices) vertices_rec = normalize_mesh(vertices_rec) mesh_rec_o3d.vertices = o3d.utility.Vector3dVector(vertices_rec) if "RealFusion" in REC_DIR or "Wonder3D_ours" in REC_DIR or "SyncDreamer" in REC_DIR: mesh_rec.vertices = trimesh.transformations.rotation_matrix(azimu[0], [0, 0, 1])[:3, :3].dot( mesh_rec.vertices.T).T # o3d R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([0., 0., azimu[0]])) mesh_rec_o3d.rotate(R, center=(0, 0, 0)) elif "dreamgaussian" in REC_DIR: mesh_rec.vertices = trimesh.transformations.rotation_matrix(azimu[0]+np.pi/2, [0, 1, 0])[:3, :3].dot( mesh_rec.vertices.T).T # rotate 90 along x mesh_rec.vertices = trimesh.transformations.rotation_matrix(np.pi/2, [1, 0, 0])[:3, :3].dot( mesh_rec.vertices.T).T # o3d R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([0., azimu[0]+np.pi/2, 0.])) mesh_rec_o3d.rotate(R, center=(0, 0, 0)) R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([np.pi/2, 0., 0.])) mesh_rec_o3d.rotate(R, center=(0, 0, 0)) elif "one2345" in REC_DIR: # rotate along z axis by azimu degree # mesh_rec.apply_transform(trimesh.transformations.rotation_matrix(-azimu, [0, 0, 1])) azimu = np.rad2deg(azimu[0]) azimu += 60 # https://github.com/One-2-3-45/One-2-3-45/issues/26 # print("azimu", 
azimu) mesh_rec.vertices = trimesh.transformations.rotation_matrix(np.radians(azimu), [0, 0, 1])[:3, :3].dot(mesh_rec.vertices.T).T # # scale again # mesh_rec, rec_center, rec_scale = normalize_mesh(mesh_rec) # o3d R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([0., 0., np.radians(azimu)])) mesh_rec_o3d.rotate(R, center=(0, 0, 0)) # # scale again # mesh_rec_o3d = mesh_rec_o3d.translate(-rec_center) # mesh_rec_o3d = mesh_rec_o3d.scale(1 / rec_scale, [0, 0, 0]) elif "PointE" in REC_DIR or "ShapeE" in REC_DIR: # sample points from rec_pc if "PointE" in REC_DIR: rec_pc_tri = pc_rec rec_pc_tri.vertices = rec_pc_tri.vertices[np.random.choice(np.arange(len(pc_rec.vertices)), N)] else: rec_pc = trimesh.sample.sample_surface(mesh_rec, N) rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0]) gt_pc = trimesh.sample.sample_surface(mesh_gt, N) gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) # loop over all flips and 90 degrees rotations of rec_pc, pick the one with the smallest chamfer distance chamfer_dist_min = np.inf opt_axis = None opt_angle = None for axis in [[1, 0, 0], [0, 1, 0], [0, 0, 1]]: for angle in [0, 90, 180, 270]: tmp_rec_pc_tri = rec_pc_tri.copy() tmp_rec_pc_tri.vertices = trimesh.transformations.rotation_matrix(np.radians(angle), axis)[:3, :3].dot(tmp_rec_pc_tri.vertices.T).T tmp_mesh_rec = mesh_rec.copy() tmp_mesh_rec.vertices = trimesh.transformations.rotation_matrix(np.radians(angle), axis)[:3, :3].dot(tmp_mesh_rec.vertices.T).T # compute chamfer distance chamfer_dist = chamfer(gt_pc_tri.vertices, tmp_rec_pc_tri.vertices) if chamfer_dist < chamfer_dist_min: chamfer_dist_min = chamfer_dist opt_axis = axis opt_angle = angle chamfer_dist = chamfer_dist_min mesh_rec.vertices = trimesh.transformations.rotation_matrix(np.radians(opt_angle), opt_axis)[:3, :3].dot(mesh_rec.vertices.T).T # o3d if np.abs(opt_angle) > 1e-6: if opt_axis == [1, 0, 0]: R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([np.radians(opt_angle), 0., 0.])) elif opt_axis == [0, 1, 
0]: R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([0., np.radians(opt_angle), 0.])) elif opt_axis == [0, 0, 1]: R = mesh_rec_o3d.get_rotation_matrix_from_xyz(np.array([0., 0., np.radians(opt_angle)])) mesh_rec_o3d.rotate(R, center=(0, 0, 0)) if "ours" in REC_DIR or "SyncDreamer" in REC_DIR: # invert the face mesh_rec.invert() # o3d Invert the mesh faces mesh_rec_o3d.triangles = o3d.utility.Vector3iVector(np.asarray(mesh_rec_o3d.triangles)[:, [0, 2, 1]]) # Compute vertex normals to ensure correct orientation mesh_rec_o3d.compute_vertex_normals() # normalize mesh_rec.vertices = normalize_mesh(mesh_rec.vertices) vertices_rec = np.asarray(mesh_rec_o3d.vertices) vertices_rec = normalize_mesh(vertices_rec) mesh_rec_o3d.vertices = o3d.utility.Vector3dVector(vertices_rec) # print("mesh_gt_o3d ", np.asarray(mesh_gt_o3d.vertices).max(0), np.asarray(mesh_gt_o3d.vertices).min(0)) # print("mesh_rec_o3d ", np.asarray(mesh_rec_o3d.vertices).max(0), np.asarray(mesh_rec_o3d.vertices).min(0)) assert np.abs(np.asarray(mesh_gt_o3d.vertices)).max() <= 0.505 assert np.abs(np.asarray(mesh_rec_o3d.vertices)).max() <= 0.505 assert np.abs(np.asarray(mesh_gt.vertices)).max() <= 0.505 assert np.abs(np.asarray(mesh_rec.vertices)).max() <= 0.505 # compute chamfer distance chamfer_dist = chamfer(mesh_gt.vertices, mesh_rec.vertices) vol_iou = compute_iou(mesh_gt, mesh_rec) CDs.append(chamfer_dist) IoUs.append(vol_iou) # # todo save screenshots # mesh_axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.0, origin=[0, 0, 0]) # # draw bbox for gt and rec # bbox_gt = mesh_gt.bounding_box.bounds # bbox_rec = mesh_rec.bounding_box.bounds # bbox_gt_o3d = o3d.geometry.AxisAlignedBoundingBox(min_bound=bbox_gt[0], max_bound=bbox_gt[1]) # bbox_rec_o3d = o3d.geometry.AxisAlignedBoundingBox(min_bound=bbox_rec[0], max_bound=bbox_rec[1]) # # color red for gt, green for rec # bbox_gt_o3d.color = (1, 0, 0) # bbox_rec_o3d.color = (0, 1, 0) # # draw a bbox of unit cube [-1, 1]^3 # bbox_unit_cube 
= o3d.geometry.AxisAlignedBoundingBox(min_bound=(-1, -1, -1), max_bound=(1, 1, 1)) # bbox_unit_cube.color = (0, 0, 1) # # # o3d.visualization.draw_geometries( # # [mesh_axis, mesh_gt_o3d, mesh_rec_o3d, bbox_gt_o3d, bbox_rec_o3d, bbox_unit_cube]) # # # take a screenshot with circle view and save to file # # save screenshot to file # vis_output = os.path.join("screenshots", method_name) # os.makedirs(vis_output, exist_ok=True) # mesh_rec_o3d.compute_vertex_normals() # # # vis = o3d.visualization.Visualizer() # # vis.create_window(width=512, height=512) # # vis.add_geometry(mesh_rec_o3d) # # # show the window and save camera pose to json file # # vis.get_render_option().light_on = True # # vis.run() # # # rgb # for i in range(6): # capture_screenshots(mesh_rec_o3d, f"ScreenCamera_{i}.json", "RenderOption_rgb.json", os.path.join(vis_output, obj_name + f"_{i}.png")) # # phong shading # for i in range(6): # capture_screenshots(mesh_rec_o3d, f"ScreenCamera_{i}.json", "RenderOption_phong.json", os.path.join(vis_output, obj_name + f"_{i}_phong.png")) # todo 3D metrics # save metrics to a single file with open(os.path.join(REC_DIR, "metrics3D.txt"), "a") as f: # write metrics in one line with format: obj_name chamfer_dist volume_iou f.write(obj_name + " CD:" + str(chamfer_dist) + " IoU:" + str(vol_iou) + "\n") # average metrics and save to the file print("Average CD:", np.mean(CDs)) print("Average IoU:", np.mean(IoUs)) with open(os.path.join(REC_DIR, "metrics3D.txt"), "a") as f: f.write("Average CD:" + str(np.mean(CDs)) + " IoU:" + str(np.mean(IoUs)) + "\n") ### TODO GT_DIR = "/home/xin/data/EscherNet/Data/GSO30/" methods = {} # methods["One2345-XL"] = "" # methods["One2345"] = "" # methods["PointE"] = "" # methods["ShapeE"] = "" # methods["DreamGaussian"] = "" # methods["DreamGaussian-XL"] = "" # methods["SyncDreamer"] = "" methods["Ours_T1"] = "/GSO3D/ours_GSO_T1/NeuS/" methods["Ours_T2"] = "/GSO3D/ours_GSO_T2/NeuS/" methods["Ours_T3"] = "/GSO3D/ours_GSO_T3/NeuS/" 
methods["Ours_T5"] = "/GSO3D/ours_GSO_T5/NeuS/" methods["Ours_T10"] = "/GSO3D/ours_GSO_T10/NeuS" for method_name in methods.keys(): print("method_name: ", method_name) vis_3D_rec(GT_DIR, methods[method_name], method_name) ================================================ FILE: metrics/metrics.py ================================================ import numpy as np from scipy.spatial import cKDTree as KDTree import mesh2sdf import open3d def chamfer(gt_points, rec_points): # one direction gen_points_kd_tree = KDTree(rec_points) one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points) gt_to_gen_chamfer = np.mean(one_distances) # other direction gt_points_kd_tree = KDTree(gt_points) two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points) gen_to_gt_chamfer = np.mean(two_distances) return (gt_to_gen_chamfer + gen_to_gt_chamfer) / 2. # compute volume iou def compute_iou(mesh_pr, mesh_gt): # trimesh to open3d mesh_gt_o3d = open3d.geometry.TriangleMesh() mesh_gt_o3d.vertices = open3d.utility.Vector3dVector(mesh_gt.vertices) mesh_gt_o3d.triangles = open3d.utility.Vector3iVector(mesh_gt.faces) mesh_rec_o3d = open3d.geometry.TriangleMesh() mesh_rec_o3d.vertices = open3d.utility.Vector3dVector(mesh_pr.vertices) mesh_rec_o3d.triangles = open3d.utility.Vector3iVector(mesh_pr.faces) size = 64 sdf_pr = mesh2sdf.compute(mesh_rec_o3d.vertices, mesh_rec_o3d.triangles, size, fix=False, return_mesh=False) sdf_gt = mesh2sdf.compute(mesh_gt_o3d.vertices, mesh_gt_o3d.triangles, size, fix=False, return_mesh=False) vol_pr = sdf_pr<0 vol_gt = sdf_gt<0 iou = np.sum(vol_pr & vol_gt)/np.sum(vol_gt | vol_pr) return iou ================================================ FILE: scripts/README.md ================================================ ### Filter Objaverse Data Filter Zero-1-to-3 rendered views (empty images) and save object ids: ```commandline python objaverse_filter.py --path /data/objaverse/views_release ``` There are 7983/798759 invalid object ids stored in 
invalid_ids.npy and 8607/798759 empty folders stored in empty_ids.npy. The file all_invalid.npy stores 8607 invalid ids.

In the end, we use 790152 objects from the Objaverse dataset. Zero-1-to-3's valid_paths.json contains 772870 ids.

### Render GSO Data
We borrow Zero-1-to-3's Blender rendering scripts. Set the GSO dataset path (`DATA_ROOT`) and the Blender path (`blender_path`) in `render_all_mvs.py`, then run:
```commandline
python render_all_mvs.py
```


================================================
FILE: scripts/blender_script_mvs.py
================================================
"""Blender script to render images of 3D models.

This script is used to render images of a 3D model. It takes the path of a
single object file via --object_path and renders images of that model. The
images are rendered from cameras sampled at random positions around the
object, each looking at the origin. The images, the per-view camera matrices
(.npy) and the normalized mesh (.obj) are saved to the output directory.

Example usage:
    blender -b -P blender_script_mvs.py -- \
        --object_path my_object.glb \
        --output_dir ./views \
        --engine CYCLES \
        --scale 0.8 \
        --num_images 12 \
        --camera_dist 1.2

Here, my_object.glb is the path to the object file to render (a .glb, .fbx or
.obj file, or an http(s) URL that is downloaded as a .glb).
""" import argparse import json import math import os import random import sys import time import urllib.request import uuid from typing import Tuple from mathutils import Vector, Matrix import numpy as np import bpy from mathutils import Vector parser = argparse.ArgumentParser() parser.add_argument( "--object_path", type=str, required=True, help="Path to the object file", ) parser.add_argument("--output_dir", type=str, default="/home/") parser.add_argument( "--engine", type=str, default="CYCLES", choices=["CYCLES", "BLENDER_EEVEE"] ) parser.add_argument("--scale", type=float, default=0.8) parser.add_argument("--num_images", type=int, default=8) parser.add_argument("--camera_dist", type=float, default=1.2) argv = sys.argv[sys.argv.index("--") + 1 :] args = parser.parse_args(argv) print('===================', args.engine, '===================') context = bpy.context scene = context.scene render = scene.render cam = scene.objects["Camera"] cam.location = (0, 1.2, 0) cam.data.lens = 35 cam.data.sensor_width = 32 cam_constraint = cam.constraints.new(type="TRACK_TO") cam_constraint.track_axis = "TRACK_NEGATIVE_Z" cam_constraint.up_axis = "UP_Y" # setup lighting bpy.ops.object.light_add(type="AREA") light2 = bpy.data.lights["Area"] light2.energy = 3000 bpy.data.objects["Area"].location[2] = 0.5 bpy.data.objects["Area"].scale[0] = 100 bpy.data.objects["Area"].scale[1] = 100 bpy.data.objects["Area"].scale[2] = 100 render.engine = args.engine render.image_settings.file_format = "PNG" render.image_settings.color_mode = "RGBA" render.resolution_x = 512 render.resolution_y = 512 render.resolution_percentage = 100 scene.cycles.device = "GPU" scene.cycles.samples = 128 scene.cycles.diffuse_bounces = 1 scene.cycles.glossy_bounces = 1 scene.cycles.transparent_max_bounces = 3 scene.cycles.transmission_bounces = 3 scene.cycles.filter_width = 0.01 scene.cycles.use_denoising = True scene.render.film_transparent = True bpy.context.preferences.addons["cycles"].preferences.get_devices() 
# Set the device_type bpy.context.preferences.addons[ "cycles" ].preferences.compute_device_type = "CUDA" # or "OPENCL" def sample_point_on_sphere(radius: float) -> Tuple[float, float, float]: theta = random.random() * 2 * math.pi phi = math.acos(2 * random.random() - 1) return ( radius * math.sin(phi) * math.cos(theta), radius * math.sin(phi) * math.sin(theta), radius * math.cos(phi), ) def sample_spherical(radius=3.0, maxz=3.0, minz=0.): correct = False while not correct: vec = np.random.uniform(-1, 1, 3) vec[2] = np.abs(vec[2]) vec = vec / np.linalg.norm(vec, axis=0) * radius if maxz > vec[2] > minz: correct = True return vec def sample_spherical(radius_min=1.5, radius_max=2.0, maxz=1.6, minz=-0.75): correct = False while not correct: vec = np.random.uniform(-1, 1, 3) # vec[2] = np.abs(vec[2]) radius = np.random.uniform(radius_min, radius_max, 1) vec = vec / np.linalg.norm(vec, axis=0) * radius[0] if maxz > vec[2] > minz: correct = True return vec def randomize_camera(): elevation = random.uniform(0., 90.) azimuth = random.uniform(0., 360) distance = random.uniform(0.8, 1.6) return set_camera_location(elevation, azimuth, distance) def set_camera_location(elevation, azimuth, distance): # from https://blender.stackexchange.com/questions/18530/ # x, y, z = sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2) # FIXME: too Far? # four for this x, y, z = sample_spherical(radius_min=1.7, radius_max=2.0, maxz=1.1, minz=0.9) # one for this # x, y, z = sample_spherical(radius_min=1.2, radius_max=2.0, maxz=2.2, minz=-2.2) # x, y, z = 0, -1.2, 0 camera = bpy.data.objects["Camera"] camera.location = x, y, z direction = - camera.location rot_quat = direction.to_track_quat('-Z', 'Y') camera.rotation_euler = rot_quat.to_euler() return camera def set_camera_location_old(elevation, azimuth, distance): # from https://blender.stackexchange.com/questions/18530/ x, y, z = sample_spherical(radius_min=1.5, radius_max=2.2, maxz=2.2, minz=-2.2) # FIXME: too Far? 
# four for this # x, y, z = sample_spherical(radius_min=1.2, radius_max=2.0, maxz=2.2, minz=-2.2) # x, y, z = 0, -1.2, 0 camera = bpy.data.objects["Camera"] camera.location = x, y, z direction = - camera.location rot_quat = direction.to_track_quat('-Z', 'Y') camera.rotation_euler = rot_quat.to_euler() return camera def randomize_lighting() -> None: reset_lighting() # light2.energy = random.uniform(300, 600) # bpy.data.objects["Area"].location[0] = random.uniform(-1., 1.) # bpy.data.objects["Area"].location[1] = random.uniform(-1., 1.) # bpy.data.objects["Area"].location[2] = random.uniform(0.5, 1.5) def reset_lighting() -> None: light2.energy = 800 bpy.data.objects["Area"].location[0] = 0 bpy.data.objects["Area"].location[1] = 0 bpy.data.objects["Area"].location[2] = 1 def reset_scene() -> None: """Resets the scene to a clean state.""" # delete everything that isn't part of a camera or a light for obj in bpy.data.objects: if obj.type not in {"CAMERA", "LIGHT"}: bpy.data.objects.remove(obj, do_unlink=True) # delete all the materials for material in bpy.data.materials: bpy.data.materials.remove(material, do_unlink=True) # delete all the textures for texture in bpy.data.textures: bpy.data.textures.remove(texture, do_unlink=True) # delete all the images for image in bpy.data.images: bpy.data.images.remove(image, do_unlink=True) # load the glb model def load_object(object_path: str) -> None: """Loads a glb model into the scene.""" if object_path.endswith(".glb"): bpy.ops.import_scene.gltf(filepath=object_path, merge_vertices=True) elif object_path.endswith(".fbx"): bpy.ops.import_scene.fbx(filepath=object_path) elif object_path.endswith(".obj"): bpy.ops.import_scene.obj(filepath=object_path, axis_forward="Y", axis_up="Z") else: raise ValueError(f"Unsupported file type: {object_path}") def scene_bbox(single_obj=None, ignore_matrix=False): bbox_min = (math.inf,) * 3 bbox_max = (-math.inf,) * 3 found = False for obj in scene_meshes() if single_obj is None else 
[single_obj]: found = True for coord in obj.bound_box: coord = Vector(coord) if not ignore_matrix: coord = obj.matrix_world @ coord bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) if not found: raise RuntimeError("no objects in scene to compute bounding box for") return Vector(bbox_min), Vector(bbox_max) def scene_root_objects(): for obj in bpy.context.scene.objects.values(): if not obj.parent: yield obj def scene_meshes(): for obj in bpy.context.scene.objects.values(): if isinstance(obj.data, (bpy.types.Mesh)): yield obj # function from https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py def get_3x4_RT_matrix_from_blender(cam): # bcam stands for blender camera # R_bcam2cv = Matrix( # ((1, 0, 0), # (0, 1, 0), # (0, 0, 1))) # Transpose since the rotation is object rotation, # and we want coordinate rotation # R_world2bcam = cam.rotation_euler.to_matrix().transposed() # T_world2bcam = -1*R_world2bcam @ location # # Use matrix_world instead to account for all constraints location, rotation = cam.matrix_world.decompose()[0:2] R_world2bcam = rotation.to_matrix().transposed() # Convert camera location to translation vector used in coordinate changes # T_world2bcam = -1*R_world2bcam @ cam.location # Use location from matrix_world to account for constraints: T_world2bcam = -1*R_world2bcam @ location # # Build the coordinate transform matrix from world to computer vision camera # R_world2cv = R_bcam2cv@R_world2bcam # T_world2cv = R_bcam2cv@T_world2bcam # put into 3x4 matrix RT = Matrix(( R_world2bcam[0][:] + (T_world2bcam[0],), R_world2bcam[1][:] + (T_world2bcam[1],), R_world2bcam[2][:] + (T_world2bcam[2],) )) return RT def normalize_scene(): bbox_min, bbox_max = scene_bbox() scale = 1 / max(bbox_max - bbox_min) for obj in scene_root_objects(): obj.scale = obj.scale * scale # Apply scale to matrix_world. 
bpy.context.view_layer.update() bbox_min, bbox_max = scene_bbox() offset = -(bbox_min + bbox_max) / 2 for obj in scene_root_objects(): obj.matrix_world.translation += offset bpy.ops.object.select_all(action="DESELECT") def save_images(object_file: str) -> None: """Saves rendered images of the object in the scene.""" os.makedirs(args.output_dir, exist_ok=True) reset_scene() # load the object load_object(object_file) object_uid = os.path.basename(object_file).split(".")[0] normalize_scene() # export mesh mesh_save_path = os.path.join(args.output_dir, object_uid + "_norm.obj") bpy.ops.export_scene.obj(filepath=mesh_save_path) # create an empty object to track empty = bpy.data.objects.new("Empty", None) scene.collection.objects.link(empty) cam_constraint.target = empty randomize_lighting() for i in range(args.num_images): # # set the camera position # theta = (i / args.num_images) * math.pi * 2 # phi = math.radians(60) # point = ( # args.camera_dist * math.sin(phi) * math.cos(theta), # args.camera_dist * math.sin(phi) * math.sin(theta), # args.camera_dist * math.cos(phi), # ) # # reset_lighting() # cam.location = point # set camera if i == 0: camera = set_camera_location(None, None, None) else: camera = set_camera_location_old(None, None, None) # render the image render_path = os.path.join(args.output_dir, object_uid, f"{i:03d}.png") scene.render.filepath = render_path bpy.ops.render.render(write_still=True) # save camera RT matrix RT = get_3x4_RT_matrix_from_blender(camera) print(RT) RT_path = os.path.join(args.output_dir, object_uid, f"{i:03d}.npy") np.save(RT_path, RT) # RT_path = os.path.join(args.output_dir, object_uid, f"{i:03d}.npy") # mat = camera.matrix_world # np.save(RT_path, mat) # print(mat) def download_object(object_url: str) -> str: """Download the object and return the path.""" # uid = uuid.uuid4() uid = object_url.split("/")[-1].split(".")[0] tmp_local_path = os.path.join("tmp-objects", f"{uid}.glb" + ".tmp") local_path = os.path.join("tmp-objects", 
f"{uid}.glb") # wget the file and put it in local_path os.makedirs(os.path.dirname(tmp_local_path), exist_ok=True) urllib.request.urlretrieve(object_url, tmp_local_path) os.rename(tmp_local_path, local_path) # get the absolute path local_path = os.path.abspath(local_path) return local_path if __name__ == "__main__": try: start_i = time.time() if args.object_path.startswith("http"): local_path = download_object(args.object_path) else: local_path = args.object_path save_images(local_path) end_i = time.time() print("Finished", local_path, "in", end_i - start_i, "seconds") # delete the object if it was downloaded if args.object_path.startswith("http"): os.remove(local_path) except Exception as e: print("Failed to render", args.object_path) print(e) ================================================ FILE: scripts/obj_ids.npy ================================================ [File too large to display: 96.5 MB] ================================================ FILE: scripts/objaverse_filter.py ================================================ # filter zero123 generated views from objaverse, filter out invalid images that are pure white import os import glob import numpy as np from tqdm import tqdm import matplotlib.pyplot as plt import shutil import argparse def filter_zero123_views(path): invalid_ids = [] objects = os.listdir(path) for obj in tqdm(objects): views = glob.glob(os.path.join(path, obj, '*.png')) # check if the number of views is 12 if len(views) != 12: invalid_ids.append(obj) print(obj, 'empty') continue # read image and check if it is pure white invalid = 0 for view in views: img = plt.imread(view) if np.all(img[:, :, -1] == 0.): invalid += 1 if invalid >= 3: invalid_ids.append(obj) print(obj, 'invalid') return invalid_ids def move_invalid_views(path, invalid_ids, invalid_path): for obj_id in tqdm(invalid_ids): # if exist, remove if os.path.exists(os.path.join(path, obj_id)): # move folder to invalid folder shutil.move(os.path.join(path, obj_id), 
os.path.join(invalid_path, obj_id)) if __name__ == '__main__': parser = argparse.ArgumentParser(description="Filter & Move Zero-1-to-3 Objaverse Rendering Data.") parser.add_argument( "--path", type=str, default="/data/objaverse/views_release", required=True, help="Path to Zero-1-to-3 Objaverse views_release Rendering Data.", ) args = parser.parse_args() path = args.path # # filter invalid views # invalid_ids = filter_zero123_views(path) # # save invalid ids # np.save('invalid_ids.npy', invalid_ids) # # print(invalid_ids) # print("Total invalid len ", len(invalid_ids)) # move invalid views invalid_ids = np.load('all_invalid.npy') invalid_path = os.path.join(path, '../invalid') os.makedirs(invalid_path, exist_ok=True) move_invalid_views(path, invalid_ids, invalid_path) ================================================ FILE: scripts/render_all_mvs.py ================================================ import os import shutil DATA_ROOT = "/home/xin/data/EscherNet/Data/GSO30" # replace with path to GSO dataset blender_path = "/home/xin/code/zero123/objaverse-rendering/blender-3.2.2-linux-x64/blender" # replace with path to blender # get the list of subfolders, only folders filenames = os.listdir(DATA_ROOT) print(filenames) print(len(filenames)) for filename in filenames: model_folder = os.path.join(DATA_ROOT, filename) obj_path = os.path.join(model_folder, "meshes/model.obj") out_dir = os.path.join(model_folder, "render_mvs_25") if not os.path.exists(obj_path): continue if os.path.exists(out_dir): # remove shutil.rmtree(out_dir) cmd = f"{blender_path} -b -P blender_script_mvs.py -- \ --object_path {obj_path} \ --output_dir {out_dir} \ --engine CYCLES \ --scale 0.8 \ --num_images 25 \ --camera_dist 1.2" os.system(cmd)